internal/metadata/metadata.go (204 lines of code) (raw):

// Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package metadata provides a client for communication with Metadata Server. package metadata import ( "context" "fmt" "io" "net/http" "net/url" "slices" "strings" "time" "github.com/GoogleCloudPlatform/galog" "github.com/GoogleCloudPlatform/google-guest-agent/internal/retry" ) const ( // defaultMetadataURL is the default endpoint used to connect to Metadata server. defaultMetadataURL = "http://169.254.169.254/computeMetadata/v1/" // defaultEtag is the default etag used when none is set. defaultEtag = "NONE" // defaultHangTimeout is the timeout parameter passed to metadata as the hang timeout. defaultHangTimeout = 60 // defaultClientTimeout sets the http.Client time out, the delta of 10s between the // defaultHangTimeout and client timeout should be enough to avoid canceling the context // before headers and body are read. defaultClientTimeout = 70 ) var ( // we backoff until 10s backoffDuration = 100 * time.Millisecond backoffAttempts = 100 ) // MDSClientInterface is the minimum required Metadata Server interface for Guest Agent. type MDSClientInterface interface { // Get returns the metadata descriptor which includes all details from MDS. Get(context.Context) (*Descriptor, error) // GetKey gets a specific metadata key. GetKey(context.Context, string, map[string]string) (string, error) // GetKeyRecursive gets a specific metadata key recursively (key and all its sub children). GetKeyRecursive(context.Context, string) (string, error) // Watch waits for any change on MDS and returns the metadata descriptor which includes all details from MDS. Watch(context.Context) (*Descriptor, error) // WriteGuestAttributes writes the key and value to guest attributes in MDS. WriteGuestAttributes(context.Context, string, string) error } // requestConfig is used internally to configure an http request given its context. type requestConfig struct { baseURL string hang bool recursive bool jsonOutput bool timeout int headers map[string]string } // Client defines the public interface between the core guest agent and // the metadata layer. type Client struct { metadataURL string etag string httpClient *http.Client } // New allocates and configures a new Client instance. func New() *Client { return &Client{ metadataURL: defaultMetadataURL, etag: defaultEtag, httpClient: &http.Client{ Timeout: defaultClientTimeout * time.Second, }, } } func (c *Client) updateEtag(resp *http.Response) bool { oldEtag := c.etag c.etag = resp.Header.Get("etag") if c.etag == "" { c.etag = defaultEtag } return c.etag != oldEtag } // MDSReqError represents custom error produced by HTTP requests made on MDS. It captures // error and status code for inspecting. type MDSReqError struct { status int err error } // Error implements method defined on error interface to transform custom type into error. func (m *MDSReqError) Error() string { return fmt.Sprintf("request failed with status code: [%d], error: [%v]", m.status, m.err) } // shouldRetry method checks if MDSReqError is temporary and retriable or not. func shouldRetry(err error) bool { e, ok := err.(*MDSReqError) if !ok { // Unknown error retry. return true } // Known non-retriable status codes. codes := []int{404} return !slices.Contains(codes, e.status) } func (c *Client) retry(ctx context.Context, cfg requestConfig) (string, error) { policy := retry.Policy{MaxAttempts: backoffAttempts, Jitter: backoffDuration, BackoffFactor: 1, ShouldRetry: shouldRetry} fn := func() (string, error) { resp, err := c.do(ctx, cfg) if err != nil { statusCode := -1 if resp != nil { statusCode = resp.StatusCode } return "", &MDSReqError{statusCode, err} } defer resp.Body.Close() md, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read metadata server response bytes: %w", err) } return string(md), nil } return retry.RunWithResponse(ctx, policy, fn) } // GetKey gets a specific metadata key. func (c *Client) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) { reqURL, err := url.JoinPath(c.metadataURL, key) if err != nil { return "", fmt.Errorf("failed to form metadata url: %w", err) } cfg := requestConfig{ baseURL: reqURL, headers: headers, } return c.retry(ctx, cfg) } // GetKeyRecursive gets a specific metadata key recursively and returns JSON output. func (c *Client) GetKeyRecursive(ctx context.Context, key string) (string, error) { reqURL, err := url.JoinPath(c.metadataURL, key) if err != nil { return "", fmt.Errorf("failed to form metadata url: %w", err) } cfg := requestConfig{ baseURL: reqURL, jsonOutput: true, recursive: true, } return c.retry(ctx, cfg) } // Watch runs a long poll on metadata server. func (c *Client) Watch(ctx context.Context) (*Descriptor, error) { return c.get(ctx, true) } // Get does a metadata call, if hang is set to true then it will do a longpoll. func (c *Client) Get(ctx context.Context) (*Descriptor, error) { return c.get(ctx, false) } func (c *Client) get(ctx context.Context, hang bool) (*Descriptor, error) { cfg := requestConfig{ baseURL: c.metadataURL, timeout: defaultHangTimeout, recursive: true, jsonOutput: true, } if hang { cfg.hang = true } resp, err := c.retry(ctx, cfg) if err != nil { return nil, err } return UnmarshalDescriptor(resp) } // WriteGuestAttributes does a put call to mds changing a guest attribute value. func (c *Client) WriteGuestAttributes(ctx context.Context, key, value string) error { galog.Debugf("write guest attribute %q", key) finalURL, err := url.JoinPath(c.metadataURL, "instance/guest-attributes/", key) if err != nil { return fmt.Errorf("failed to form metadata url: %w", err) } galog.Debugf("Requesting(PUT) MDS URL: %s", finalURL) req, err := http.NewRequestWithContext(ctx, "PUT", finalURL, strings.NewReader(value)) if err != nil { return err } req.Header.Add("Metadata-Flavor", "Google") _, err = c.httpClient.Do(req) return err } func (c *Client) do(ctx context.Context, cfg requestConfig) (*http.Response, error) { finalURL, err := url.Parse(cfg.baseURL) if err != nil { return nil, fmt.Errorf("failed to parse url: %w", err) } values := finalURL.Query() if cfg.hang { values.Add("wait_for_change", "true") values.Add("last_etag", c.etag) } if cfg.timeout > 0 { values.Add("timeout_sec", fmt.Sprintf("%d", cfg.timeout)) } if cfg.recursive { values.Add("recursive", "true") } if cfg.jsonOutput { values.Add("alt", "json") } finalURL.RawQuery = values.Encode() galog.Debugf("Requesting(GET) MDS URL: %s", finalURL.String()) req, err := http.NewRequestWithContext(ctx, "GET", finalURL.String(), nil) if err != nil { return nil, err } req.Header.Add("Metadata-Flavor", "Google") for k, v := range cfg.headers { req.Header.Add(k, v) } resp, err := c.httpClient.Do(req) // If we are canceling httpClient will also wrap the context's error so // check first the context. if ctx.Err() != nil { return resp, ctx.Err() } if err != nil { return resp, fmt.Errorf("error connecting to metadata server: %w", err) } if resp == nil { return nil, fmt.Errorf("got nil response from metadata server") } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() // Ignore read error as we are returning original error and wrapping MDS error code. r, _ := io.ReadAll(resp.Body) return resp, fmt.Errorf("invalid response from metadata server, status code: %d, reason: %s", resp.StatusCode, string(r)) } if cfg.hang { c.updateEtag(resp) } return resp, nil }