metadata/metadata.go (398 lines of code) (raw):

// Copyright 2017 Google LLC

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     https://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package metadata implements an HTTP client for the metadata server (MDS):
// one-shot and longpoll reads of the metadata descriptor, single key lookups
// and guest attribute writes.
package metadata

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"slices"
	"strconv"
	"strings"
	"time"

	"github.com/GoogleCloudPlatform/guest-agent/retry"
	"github.com/GoogleCloudPlatform/guest-logging-go/logger"
)

const (
	// defaultMetadataURL is the base URL used by clients created with New.
	defaultMetadataURL = "http://169.254.169.254/computeMetadata/v1/"

	// defaultEtag is the etag a new client starts with, and the value the
	// client falls back to when a response carries no etag header.
	defaultEtag = "NONE"

	// defaultHangTimeout is the timeout parameter, in seconds, passed to metadata as
	// the hang timeout (sent as the timeout_sec query parameter).
	defaultHangTimeout = 60

	// defaultClientTimeout sets the http.Client time out, in seconds. The delta of 10s
	// between the defaultHangTimeout and client timeout should be enough to avoid
	// canceling the context before headers and body are read.
	defaultClientTimeout = 70
)

var (
	// backoffDuration and backoffAttempts configure the retry policy of read
	// requests: 100 attempts spaced by 100ms add up to roughly 10s of backoff.
	// NOTE(review): the exact spacing is decided by the retry package - confirm there.
	backoffDuration = 100 * time.Millisecond
	backoffAttempts = 100
)

// MDSClientInterface is the minimum required Metadata Server interface for Guest Agent.
type MDSClientInterface interface {
	// Get returns the metadata descriptor from a single request.
	Get(context.Context) (*Descriptor, error)
	// GetKey returns the value of one metadata key; the map carries extra
	// request headers.
	GetKey(context.Context, string, map[string]string) (string, error)
	// GetKeyRecursive returns the JSON output of a recursive lookup of a key.
	GetKeyRecursive(context.Context, string) (string, error)
	// Watch runs a longpoll on the metadata server and returns the descriptor.
	Watch(context.Context) (*Descriptor, error)
	// WriteGuestAttributes sets the guest attribute named by the first string
	// to the value given by the second one.
	WriteGuestAttributes(context.Context, string, string) error
}

// requestConfig is used internally to configure an http request given its context.
type requestConfig struct { baseURL string hang bool recursive bool jsonOutput bool timeout int headers map[string]string } // Client defines the public interface between the core guest agent and // the metadata layer. type Client struct { metadataURL string etag string httpClient *http.Client } // New allocates and configures a new Client instance. func New() *Client { return &Client{ metadataURL: defaultMetadataURL, etag: defaultEtag, httpClient: &http.Client{ Timeout: defaultClientTimeout * time.Second, }, } } // Descriptor wraps/holds all the metadata keys, the structure reflects the json // descriptor returned with metadata call with alt=jason. type Descriptor struct { Instance Instance Project Project } // UnmarshalJSON unmarshals b into Descritor. func (m *Descriptor) UnmarshalJSON(b []byte) error { // We can't unmarshal into metadata directly as it would create an infinite loop. type temp Descriptor var t temp err := json.Unmarshal(b, &t) if err == nil { *m = Descriptor(t) return nil } // If this is a syntax error return a useful error. sErr, ok := err.(*json.SyntaxError) if !ok { return err } // Byte number where the error line starts. start := bytes.LastIndex(b[:sErr.Offset], []byte("\n")) + 1 // Assume end byte of error line is EOF unless this isn't the last line. end := len(b) if i := bytes.Index(b[start:], []byte("\n")); i >= 0 { end = start + i } // Position of error in line (where to place the '^'). pos := int(sErr.Offset) - start if pos != 0 { pos = pos - 1 } return fmt.Errorf("JSON syntax error: %s \n%s\n%s^", err, b[start:end], strings.Repeat(" ", pos)) } type virtualClock struct { DriftToken int `json:"drift-token"` } // Instance describes the metadata's instance attributes/keys. type Instance struct { // ID is the instance ID. ID json.Number // MachineType represents the instance's machine type. MachineType string // Attributes are the instance's attributes. 
Attributes Attributes // NetworkInterfaces contains all configured regular network interfaces (primary and secondary). NetworkInterfaces []NetworkInterfaces // VlanNetworkInterfaces contains all the vLAN network interfaces. VlanNetworkInterfaces map[int]map[int]VlanInterface // VirtualClock contains the drift-token attribute. VirtualClock virtualClock } // NetworkInterfaces describes the instances network interfaces configurations. type NetworkInterfaces struct { ForwardedIps []string ForwardedIpv6s []string TargetInstanceIps []string IPAliases []string Mac string DHCPv6Refresh string MTU int } // VlanInterface describes the instances vlan network interfaces configurations. type VlanInterface struct { // Mac is the vLAN interface's mac address. Mac string // ParentInterface is the mds reference of the parent/physical interface i.e.: // /computeMetadata/v1/instance/network-interfaces/0/ ParentInterface string // Vlan is the vlan id. Vlan int // MTU is the vlan's MTU value. MTU int // IP is the vlan's ip address. IP string // IPv6 is the vlan's ipv6 address. IPv6 []string // Gateway is the vlan's gateway address. Gateway string // GatewayIPv6 is the vlan's IPv6 gateway address. GatewayIPv6 string // DHCPv6Refresh determine if VLAN NIC supports IPV6. DHCPv6Refresh string } // Project describes the projects instance's attributes. type Project struct { Attributes Attributes ProjectID string NumericProjectID json.Number } // Attributes describes the project's attributes keys. type Attributes struct { CreatedBy string BlockProjectKeys bool HTTPSMDSEnableNativeStore *bool DisableHTTPSMdsSetup *bool EnableOSLogin *bool EnableWindowsSSH *bool TwoFactor *bool SecurityKey *bool RequireCerts *bool SSHKeys []string WindowsKeys WindowsKeys Diagnostics string DisableAddressManager *bool DisableAccountManager *bool EnableDiagnostics *bool EnableWSFC *bool WSFCAddresses string WSFCAgentPort string DisableTelemetry bool } // UnmarshalJSON unmarshals b into Attribute. 
func (a *Attributes) UnmarshalJSON(b []byte) error { var mkbool = func(value bool) *bool { res := new(bool) *res = value return res } // Unmarshal to literal JSON types before doing anything else. type inner struct { CreatedBy string `json:"created-by"` BlockProjectKeys string `json:"block-project-ssh-keys"` Diagnostics string `json:"diagnostics"` DisableAccountManager string `json:"disable-account-manager"` DisableAddressManager string `json:"disable-address-manager"` EnableDiagnostics string `json:"enable-diagnostics"` EnableOSLogin string `json:"enable-oslogin"` EnableWindowsSSH string `json:"enable-windows-ssh"` EnableWSFC string `json:"enable-wsfc"` OldSSHKeys string `json:"sshKeys"` SSHKeys string `json:"ssh-keys"` TwoFactor string `json:"enable-oslogin-2fa"` SecurityKey string `json:"enable-oslogin-sk"` RequireCerts string `json:"enable-oslogin-certificates"` WindowsKeys WindowsKeys `json:"windows-keys"` WSFCAddresses string `json:"wsfc-addrs"` WSFCAgentPort string `json:"wsfc-agent-port"` DisableTelemetry string `json:"disable-guest-telemetry"` DisableHTTPSMdsSetup string `json:"disable-https-mds-setup"` HTTPSMDSEnableNativeStore string `json:"enable-https-mds-native-cert-store"` } var temp inner if err := json.Unmarshal(b, &temp); err != nil { return err } a.Diagnostics = temp.Diagnostics a.WSFCAddresses = temp.WSFCAddresses a.WSFCAgentPort = temp.WSFCAgentPort a.WindowsKeys = temp.WindowsKeys a.CreatedBy = temp.CreatedBy value, err := strconv.ParseBool(temp.DisableHTTPSMdsSetup) if err == nil { a.DisableHTTPSMdsSetup = mkbool(value) } value, err = strconv.ParseBool(temp.HTTPSMDSEnableNativeStore) if err == nil { a.HTTPSMDSEnableNativeStore = mkbool(value) } value, err = strconv.ParseBool(temp.BlockProjectKeys) if err == nil { a.BlockProjectKeys = value } value, err = strconv.ParseBool(temp.EnableDiagnostics) if err == nil { a.EnableDiagnostics = mkbool(value) } value, err = strconv.ParseBool(temp.DisableAccountManager) if err == nil { 
a.DisableAccountManager = mkbool(value) } value, err = strconv.ParseBool(temp.DisableAddressManager) if err == nil { a.DisableAddressManager = mkbool(value) } value, err = strconv.ParseBool(temp.EnableOSLogin) if err == nil { a.EnableOSLogin = mkbool(value) } value, err = strconv.ParseBool(temp.EnableWindowsSSH) if err == nil { a.EnableWindowsSSH = mkbool(value) } value, err = strconv.ParseBool(temp.EnableWSFC) if err == nil { a.EnableWSFC = mkbool(value) } value, err = strconv.ParseBool(temp.TwoFactor) if err == nil { a.TwoFactor = mkbool(value) } value, err = strconv.ParseBool(temp.SecurityKey) if err == nil { a.SecurityKey = mkbool(value) } value, err = strconv.ParseBool(temp.RequireCerts) if err == nil { a.RequireCerts = mkbool(value) } value, err = strconv.ParseBool(temp.DisableTelemetry) if err == nil { a.DisableTelemetry = value } // So SSHKeys will be nil instead of []string{} if temp.SSHKeys != "" { a.SSHKeys = strings.Split(temp.SSHKeys, "\n") } if temp.OldSSHKeys != "" { a.BlockProjectKeys = true a.SSHKeys = append(a.SSHKeys, strings.Split(temp.OldSSHKeys, "\n")...) } return nil } func (c *Client) updateEtag(resp *http.Response) bool { oldEtag := c.etag c.etag = resp.Header.Get("etag") if c.etag == "" { c.etag = defaultEtag } return c.etag != oldEtag } // MDSReqError represents custom error produced by HTTP requests made on MDS. It captures // error and HTTP response for inspecting status code. type MDSReqError struct { status int err error } // Error implements method defined on error interface to transform custom type into error. func (m *MDSReqError) Error() string { return fmt.Sprintf("request failed with status code: [%d], error: [%v]", m.status, m.err) } // shouldRetry method checks if MDSReqError is temporary and retriable or not. func shouldRetry(err error) bool { e, ok := err.(*MDSReqError) if !ok { // Unknown error retry. return true } // Known non-retriable status codes. 
codes := []int{404} return !slices.Contains(codes, e.status) } func (c *Client) retry(ctx context.Context, cfg requestConfig) (string, error) { policy := retry.Policy{MaxAttempts: backoffAttempts, Jitter: backoffDuration, BackoffFactor: 1, ShouldRetry: shouldRetry} fn := func() (string, error) { resp, err := c.do(ctx, cfg) if err != nil { statusCode := -1 if resp != nil { statusCode = resp.StatusCode } return "", &MDSReqError{statusCode, err} } defer resp.Body.Close() md, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("failed to read metadata server response bytes: %+v", err) } return string(md), nil } return retry.RunWithResponse(ctx, policy, fn) } // GetKey gets a specific metadata key. func (c *Client) GetKey(ctx context.Context, key string, headers map[string]string) (string, error) { reqURL, err := url.JoinPath(c.metadataURL, key) if err != nil { return "", fmt.Errorf("failed to form metadata url: %+v", err) } cfg := requestConfig{ baseURL: reqURL, headers: headers, } return c.retry(ctx, cfg) } // GetKeyRecursive gets a specific metadata key recursively and returns JSON output. func (c *Client) GetKeyRecursive(ctx context.Context, key string) (string, error) { reqURL, err := url.JoinPath(c.metadataURL, key) if err != nil { return "", fmt.Errorf("failed to form metadata url: %+v", err) } cfg := requestConfig{ baseURL: reqURL, jsonOutput: true, recursive: true, } return c.retry(ctx, cfg) } // Watch runs a longpoll on metadata server. func (c *Client) Watch(ctx context.Context) (*Descriptor, error) { return c.get(ctx, true) } // Get does a metadata call, if hang is set to true then it will do a longpoll. 
func (c *Client) Get(ctx context.Context) (*Descriptor, error) { return c.get(ctx, false) } func (c *Client) get(ctx context.Context, hang bool) (*Descriptor, error) { cfg := requestConfig{ baseURL: c.metadataURL, timeout: defaultHangTimeout, recursive: true, jsonOutput: true, } if hang { cfg.hang = true } resp, err := c.retry(ctx, cfg) if err != nil { return nil, err } var ret Descriptor if err = json.Unmarshal([]byte(resp), &ret); err != nil { return nil, err } return &ret, nil } // WriteGuestAttributes does a put call to mds changing a guest attribute value. func (c *Client) WriteGuestAttributes(ctx context.Context, key, value string) error { logger.Debugf("write guest attribute %q", key) finalURL, err := url.JoinPath(c.metadataURL, "instance/guest-attributes/", key) if err != nil { return fmt.Errorf("failed to form metadata url: %+v", err) } logger.Debugf("Requesting(PUT) MDS URL: %s", finalURL) // This is a arbitrary retry number. policy := retry.Policy{MaxAttempts: 10, Jitter: backoffDuration, BackoffFactor: 1} putCall := func() error { req, err := http.NewRequest("PUT", finalURL, strings.NewReader(value)) if err != nil { return err } req.Header.Add("Metadata-Flavor", "Google") req = req.WithContext(ctx) _, err = c.httpClient.Do(req) return err } return retry.Run(ctx, policy, putCall) } func (c *Client) do(ctx context.Context, cfg requestConfig) (*http.Response, error) { finalURL, err := url.Parse(cfg.baseURL) if err != nil { return nil, fmt.Errorf("failed to parse url: %+v", err) } values := finalURL.Query() if cfg.hang { values.Add("wait_for_change", "true") values.Add("last_etag", c.etag) } if cfg.timeout > 0 { values.Add("timeout_sec", fmt.Sprintf("%d", cfg.timeout)) } if cfg.recursive { values.Add("recursive", "true") } if cfg.jsonOutput { values.Add("alt", "json") } finalURL.RawQuery = values.Encode() logger.Debugf("Requesting(GET) MDS URL: %s", finalURL.String()) req, err := http.NewRequestWithContext(ctx, "GET", finalURL.String(), nil) if err != nil { 
return nil, err } req.Header.Add("Metadata-Flavor", "Google") for k, v := range cfg.headers { req.Header.Add(k, v) } resp, err := c.httpClient.Do(req) // If we are canceling httpClient will also wrap the context's error so // check first the context. if ctx.Err() != nil { return resp, ctx.Err() } if err != nil { return resp, fmt.Errorf("error connecting to metadata server: %+v", err) } if resp == nil { return nil, fmt.Errorf("got nil response from metadata server") } if resp.StatusCode != http.StatusOK { defer resp.Body.Close() // Ignore read error as we are returning original error and wrapping MDS error code. r, _ := io.ReadAll(resp.Body) return resp, fmt.Errorf("invalid response from metadata server, status code: %d, reason: %s", resp.StatusCode, string(r)) } if cfg.hang { c.updateEtag(resp) } return resp, nil }