e2etest/newe2e_arm_client.go (155 lines of code) (raw):
package e2etest
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/Azure/azure-storage-azcopy/v10/common"
"io"
"net/http"
"net/url"
"reflect"
)
// ARMSubject is anything ARM requests can be issued against: the ARMClient itself, a subscription,
// or a resource group. It supplies everything PerformRequest needs to build and authorize a request.
type ARMSubject interface {
	// ManagementURI returns the base URI that request paths and query parameters are added to.
	ManagementURI() url.URL

	// Token returns the credential whose fresh token is sent as the bearer Authorization header.
	Token() AccessToken

	// Client returns the ARMClient whose HTTP client actually sends the request.
	Client() *ARMClient
}
// ARMRequestPreparer is an optional extension of ARMSubject.
//
// When a subject implements it, PerformRequest calls PrepareRequest with a pointer to its own copy of the
// request settings, before the HTTP request is built. The implementation may modify those settings in place.
type ARMRequestPreparer interface {
	PrepareRequest(settings *ARMRequestSettings)
}
// CombineQuery merges two sets of query parameters into a newly allocated url.Values.
//
// For keys present in both inputs, the values from a come before the values from b. Neither input is
// modified, the result never shares backing arrays with them, and nil inputs are treated as empty.
// The result is always non-nil.
func CombineQuery(a, b url.Values) url.Values {
	merged := url.Values{}
	for _, src := range [...]url.Values{a, b} {
		for key, vals := range src {
			merged[key] = append(merged[key], vals...)
		}
	}
	return merged
}
// Compile-time assertions that every ARM subject type implements ARMSubject.
// Declaring these as blank variables (rather than building a slice in init) costs nothing at runtime.
var (
	_ ARMSubject = (*ARMClient)(nil)
	_ ARMSubject = (*ARMSubscription)(nil)
	_ ARMSubject = (*ARMResourceGroup)(nil)
)
// ARMUnimplementedStruct is used to fill in the blanks in types when implementation doesn't seem like a good use of time.
// It can be explored with "Get", if you know precisely what you want.
type ARMUnimplementedStruct json.RawMessage

// Get follows the chain of object keys in Key (outermost first) through the raw JSON. It then unmarshals the
// value found at the end of that path into out. An empty Key unmarshals the whole document.
//
// out must be a non-nil pointer. An error is returned if:
//   - out is not a non-nil pointer;
//   - any value along the path is not a JSON object;
//   - a key is missing;
//   - the final value cannot be unmarshalled into out.
func (s ARMUnimplementedStruct) Get(Key []string, out interface{}) error {
	// reflect.ValueOf(nil) yields an Invalid kind, so a nil interface is rejected here too instead of panicking.
	rv := reflect.ValueOf(out)
	if rv.Kind() != reflect.Pointer || rv.IsNil() {
		return errors.New("out must be a non-nil pointer")
	}

	object := json.RawMessage(s)
	for i, key := range Key {
		var dict map[string]json.RawMessage
		if err := json.Unmarshal(object, &dict); err != nil {
			return fmt.Errorf("decoding object at path %q: %w", Key[:i], err)
		}
		next, ok := dict[key]
		if !ok {
			return fmt.Errorf("key %q not found at path %q", key, Key[:i])
		}
		object = next
	}

	return json.Unmarshal(object, out)
}
// ARMClient is the root ARM subject. It holds the credential used to authorize requests and, optionally,
// the HTTP client used to send them.
type ARMClient struct {
	OAuth      AccessToken  // credential handed out by Token
	HttpClient *http.Client // when nil, http.DefaultClient is used instead
}
// Client satisfies ARMSubject; an ARMClient is its own client.
func (c *ARMClient) Client() *ARMClient { return c }
// getHTTPClient returns the configured HttpClient, or http.DefaultClient when none has been set.
func (c *ARMClient) getHTTPClient() *http.Client {
	client := c.HttpClient
	if client == nil {
		client = http.DefaultClient
	}
	return client
}
// Token satisfies ARMSubject by returning the client's OAuth credential.
func (c *ARMClient) Token() AccessToken { return c.OAuth }
// ManagementURI satisfies ARMSubject by returning https://management.azure.com/ as a url.URL value.
// The URL is re-parsed and returned by value on every call, so callers may modify their copy freely.
func (c *ARMClient) ManagementURI() url.URL {
	const endpoint = "https://management.azure.com/"

	parsed, err := url.Parse(endpoint)
	common.PanicIfErr(err) // endpoint is a well-formed constant; this should never fire

	return *parsed
}
// ARMRequestSettings describes one ARM request relative to a subject's ManagementURI.
// All values will be added to the request.
type ARMRequestSettings struct {
	Method        string      // HTTP method passed to http.NewRequest
	PathExtension string      // joined onto the URL path when non-empty
	Query         url.Values  // encoded and appended to any query already on the base URI
	Headers       http.Header // placed on the request built by CreateRequest
	Body          interface{} // JSON-marshalled into the request body when non-nil
}
// CreateRequest builds an *http.Request from s, relative to baseURI.
//
// The request is assembled as follows:
//   - s.Query is appended to any query already on baseURI;
//   - s.PathExtension, if non-empty, is joined onto the URL path;
//   - s.Body, if non-nil, is JSON-marshalled into the request body;
//   - a copy of s.Headers becomes the request's headers.
//
// The returned request always has a non-nil Header map, and mutating it never affects s.Headers.
// An error is returned if the body cannot be marshalled or the request cannot be constructed.
func (s *ARMRequestSettings) CreateRequest(baseURI url.URL) (*http.Request, error) {
	// Only insert a separator when both sides contribute parameters, so no dangling "&" is produced.
	if extra := s.Query.Encode(); extra != "" {
		if baseURI.RawQuery != "" {
			baseURI.RawQuery += "&"
		}
		baseURI.RawQuery += extra
	}

	var body io.Reader
	if s.Body != nil {
		buf, err := json.Marshal(s.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal body: %w", err)
		}
		// A *bytes.Reader lets http.NewRequest populate ContentLength and GetBody.
		body = bytes.NewReader(buf)
	}

	newReq, err := http.NewRequest(s.Method, baseURI.String(), body)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Copy instead of aliasing the caller's map, so header edits made on the request later don't leak back
	// into s. When s.Headers is nil, keep the empty (non-nil) Header that http.NewRequest allocated.
	if s.Headers != nil {
		newReq.Header = s.Headers.Clone()
	}

	if s.PathExtension != "" {
		newReq.URL = newReq.URL.JoinPath(s.PathExtension)
	}

	return newReq, nil
}
// PerformRequest sends the request described by reqSettings on behalf of subject and decodes the JSON
// response into target.
//
// If subject implements ARMRequestPreparer, it may amend reqSettings before the request is built against
// subject.ManagementURI(). The request carries any headers from reqSettings.Headers, plus these headers,
// which always override caller-supplied values:
//   - Authorization: bearer token from subject.Token();
//   - Content-Type and Accept.
//
// target is expected to point at the response shape. It may be nil to discard the response body.
//
// Outcomes by response status:
//   - 200/201: a non-empty body is unmarshalled into target, and (nil, nil) is returned.
//   - 202 with an Azure-AsyncOperation or Location header: the results of ResolveAzureAsyncOperation for
//     that URL are returned unchanged.
//   - 202 without either header: an error if the response declares an empty body, otherwise handled like 200/201.
//   - any other status: an error carrying the status code and the raw response body.
func PerformRequest[Props any](subject ARMSubject, reqSettings ARMRequestSettings, target *Props) (armResp *ARMAsyncResponse[Props], err error) {
	c := subject.Client()
	baseURI := subject.ManagementURI()
	client := c.getHTTPClient()

	if prep, ok := subject.(ARMRequestPreparer); ok {
		prep.PrepareRequest(&reqSettings)
	}

	r, err := reqSettings.CreateRequest(baseURI)
	if err != nil {
		return nil, fmt.Errorf("failed to prepare request: %w", err)
	}

	token := subject.Token()
	oAuthToken, err := token.FreshToken()
	if err != nil {
		return nil, fmt.Errorf("failed to get OAuth token: %w", err)
	}

	// Start from a copy of the caller-supplied headers (never mutate reqSettings.Headers) and layer the
	// headers ARM requires on top of them.
	headers := r.Header.Clone()
	if headers == nil {
		headers = make(http.Header)
	}
	headers.Set("Authorization", "Bearer "+oAuthToken)
	headers.Set("Content-Type", "application/json; charset=utf-8")
	headers.Set("Accept", "application/json; charset=utf-8")
	r.Header = headers

	resp, err := client.Do(r)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusAccepted: // LRO pattern; grab Azure-AsyncOperation and resolve it.
		newTarget := resp.Header.Get("Azure-AsyncOperation")
		if newTarget == "" {
			newTarget = resp.Header.Get("Location")
		}
		if newTarget != "" {
			// Poll with the same credential that authorized the initial request.
			return ResolveAzureAsyncOperation(token, newTarget, target)
		}
		if resp.ContentLength == 0 {
			return nil, fmt.Errorf("failed to handle async operation: no response data, Azure-Asyncoperation and Location are not found")
		}
		// If we don't have an asyncop to check against, pull the body
		fallthrough
	case http.StatusOK, http.StatusCreated: // immediate response
		buf, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf("failed to read response body (resp code %d): %w", resp.StatusCode, readErr)
		}
		if len(buf) != 0 && target != nil {
			if jsonErr := json.Unmarshal(buf, target); jsonErr != nil {
				return nil, fmt.Errorf("failed to parse response body: %w", jsonErr)
			}
		}
		return nil, nil
	default:
		rBody, readErr := io.ReadAll(resp.Body)
		if readErr != nil {
			return nil, fmt.Errorf("failed to read response body (resp code %d): %w", resp.StatusCode, readErr)
		}
		return nil, fmt.Errorf("failed to get access (resp code %d): %s", resp.StatusCode, string(rBody))
	}
}