client/internal/imds/imds.go (83 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package imds
//go:generate ../../bin/mockgen -copyright_file=../../../hack/copyright_header.txt -destination=./mocks/mock_imds.go -package=mocks github.com/Azure/aks-secure-tls-bootstrap/client/internal/imds Client
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/datamodel"
internalhttp "github.com/Azure/aks-secure-tls-bootstrap/client/internal/http"
"go.uber.org/zap"
)
// Client retrieves data from the instance metadata service (IMDS).
// The go:generate directive at the top of this file builds a mock of this
// interface for tests.
type Client interface {
// GetInstanceData fetches the instance metadata document from IMDS and
// returns it decoded as VMSSInstanceData.
GetInstanceData(ctx context.Context) (*datamodel.VMSSInstanceData, error)
// GetAttestedData fetches the attested data document from IMDS, forwarding
// nonce as a query parameter, and returns it decoded as VMSSAttestedData.
GetAttestedData(ctx context.Context, nonce string) (*datamodel.VMSSAttestedData, error)
}
// client is the concrete Client implementation; it issues HTTP GET requests
// against paths under baseURL and decodes the JSON responses.
type client struct {
// baseURL is the IMDS root URL; endpoint paths are joined onto it with "/".
baseURL string
// httpClient sends every IMDS request (NewClient builds it with internalhttp.NewClient).
httpClient *http.Client
// logger records which IMDS endpoint is being called.
logger *zap.Logger
}
// Compile-time check that *client satisfies the Client interface.
var _ Client = (*client)(nil)
// NewClient returns a Client rooted at imdsURL. Its HTTP client is obtained
// from the internal http package, and both that HTTP client and the returned
// Client log through the supplied logger.
func NewClient(logger *zap.Logger) Client {
	c := new(client)
	c.baseURL = imdsURL
	c.httpClient = internalhttp.NewClient(logger)
	c.logger = logger
	return c
}
func (c *client) GetInstanceData(ctx context.Context) (*datamodel.VMSSInstanceData, error) {
url := fmt.Sprintf("%s/%s", c.baseURL, instanceDataEndpoint)
c.logger.Info("calling IMDS instance data endpoint", zap.String("url", url))
params := getCommonParameters()
params[formatParameterKey] = "json"
var data datamodel.VMSSInstanceData
if err := c.callIMDS(ctx, url, params, &data); err != nil {
return nil, fmt.Errorf("failed to retrieve IMDS instance data: %w", err)
}
return &data, nil
}
func (c *client) GetAttestedData(ctx context.Context, nonce string) (*datamodel.VMSSAttestedData, error) {
url := fmt.Sprintf("%s/%s", c.baseURL, attestedDataEndpoint)
c.logger.Info("calling IMDS attested data endpoint", zap.String("url", url))
params := getCommonParameters()
params[formatParameterKey] = "json"
params[nonceParameterKey] = nonce
var data datamodel.VMSSAttestedData
if err := c.callIMDS(ctx, url, params, &data); err != nil {
return nil, fmt.Errorf("failed to retrieve IMDS attested data: %w", err)
}
return &data, nil
}
// callIMDS sends a GET request to url with the given query parameters and the
// IMDS metadata header, then JSON-decodes the response body into
// responseObject (which must be a pointer, as json.Unmarshal requires).
//
// It returns an error if the request cannot be built or sent, if the body
// cannot be read, if IMDS answers with any status other than 200 OK, or if
// the body cannot be decoded into responseObject.
func (c *client) callIMDS(ctx context.Context, url string, queryParameters map[string]string, responseObject interface{}) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("failed to construct new HTTP request to IMDS: %w", err)
	}
	// IMDS is documented to reject requests that lack its metadata header.
	req.Header.Add(metadataHeaderKey, "True")

	query := req.URL.Query()
	for key, value := range queryParameters {
		query.Add(key, value)
	}
	req.URL.RawQuery = query.Encode()

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to do HTTP request to IMDS: %w", err)
	}
	// When Do returns a nil error, resp and resp.Body are guaranteed non-nil.
	defer resp.Body.Close()

	// Read the body before checking the status so an error reply can be
	// surfaced to the caller and the connection can be reused.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read IMDS response body: %w", err)
	}

	// Without this check, an IMDS error payload could unmarshal cleanly into
	// responseObject and be returned as empty data with a nil error.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status code %d from IMDS: %q", resp.StatusCode, body)
	}

	if err := json.Unmarshal(body, responseObject); err != nil {
		return fmt.Errorf("failed to unmarshal IMDS data: %w", err)
	}
	return nil
}
// getCommonParameters builds the query parameters shared by every IMDS call;
// currently this is only the API version. A new map is allocated on each
// call, so callers may add endpoint-specific keys without affecting one
// another.
func getCommonParameters() map[string]string {
	common := make(map[string]string, 1)
	common[apiVersionParameterKey] = apiVersion
	return common
}