internal/acs/client/client.go (225 lines of code) (raw):

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package client provides client methods to communicate with ACS.
package client

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	client "github.com/GoogleCloudPlatform/agentcommunication_client"
	"github.com/GoogleCloudPlatform/agentcommunication_client/gapic"
	acpb "github.com/GoogleCloudPlatform/agentcommunication_client/gapic/agentcommunicationpb"
	"github.com/GoogleCloudPlatform/galog"
	acmpb "github.com/GoogleCloudPlatform/google-guest-agent/internal/acp/proto/google_guest_agent/acp"
	"github.com/GoogleCloudPlatform/google-guest-agent/internal/cfg"
	"github.com/GoogleCloudPlatform/google-guest-agent/internal/retry"
	"google.golang.org/api/option"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/protobuf/proto"
	anypb "google.golang.org/protobuf/types/known/anypb"
)

const (
	// messageType is the key in message labels that carries the message type.
	messageType = "message_type"
	// defaultAgentChannelID is the default channel ID used by Guest Agent for
	// ACS when no override is configured.
	defaultAgentChannelID = "compute.googleapis.com/google-guest-agent"
	// pluginEventMessageMsg is the message type label to use with any event
	// notifications sent by the agent.
	pluginEventMessageMsg = "agent_controlplane.PluginEventMessage"
)

var (
	// Default retry policy to retry max 5 times for up to ~30 sec.
	// NOTE: connect() lowers MaxAttempts to 2 when running in the unit-test
	// override mode (OverrideConnection set on the context).
	defaultRetrypolicy = retry.Policy{MaxAttempts: 5, Jitter: time.Second, BackoffFactor: 2}
	// Watcher retry policy to retry indefinitely. Watcher could fail if there is
	// no connection to ACS (network issue), could be a transient error or unable
	// to retrieve the identity token. Default retry policy will be too frequent
	// and will not achieve much in this case but overwhelm instance by consuming
	// too much CPU/memory.
	watcherRetrypolicy = retry.Policy{Jitter: time.Second * 2, BackoffFactor: 2, MaximumBackoff: time.Minute * 20}
	// acs manages ACS connection and implements all ACS client methods.
	acs = &acsHelper{}
)

// ConnectionInterface is the minimum interface required by Agent to communicate
// with ACS over a streaming connection.
type ConnectionInterface interface {
	SendMessage(msg *acpb.MessageBody) error
	Receive() (*acpb.MessageBody, error)
	Close()
}

// acsHelper is a helper struct to manage ACS connection.
type acsHelper struct {
	// channelID is the channel ID used by this ACS connection.
	channelID string
	// connMu protects connection (and the client/state set alongside it).
	connMu sync.Mutex
	// connection is cached ACS Connection, reused across Agent if it's already
	// set. connection uses the acsClient to send and receive messages.
	connection ConnectionInterface
	// client is a client for interacting with Agent Communication API.
	client *agentcommunication.Client
	// isConnectionSet is set if connection is initialized and still valid.
	// It is reset to false in the error case to trigger reconnection.
	isConnectionSet atomic.Bool
}

// ContextKey is the context key type to use for overriding.
type ContextKey string

// OverrideConnection is the key for context to override client connection.
// When present, its value must implement ConnectionInterface; it is used
// instead of creating a real ACS connection (intended for unit tests).
const OverrideConnection ContextKey = "override_connection"

// channelID returns the channel ID to use for ACS connections: the ACS
// ChannelID config entry when it is non-empty, otherwise defaultAgentChannelID.
func channelID() string { override := cfg.Retrieve().ACS.ChannelID if override != "" { galog.Debugf("Using overridden ACS channel ID: %q", override) return override } return defaultAgentChannelID } // clientOptions returns client options for creating ACS connection. // The endpoint or host address option can be overridden by a config entry. func clientOptions(ctx context.Context) ([]option.ClientOption, error) { var opts []option.ClientOption endpoint := cfg.Retrieve().ACS.Endpoint if endpoint != "" { galog.Debugf("Using overridden ACS endpoint: %q", endpoint) opts = append(opts, option.WithEndpoint(endpoint)) } addr := cfg.Retrieve().ACS.Host if addr == "" { return opts, nil } addr = "unix:" + addr galog.Debugf("Using overridden ACS server address: %q", addr) options := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), } conn, err := grpc.NewClient(addr, options...) if err != nil { return nil, fmt.Errorf("unable to dial ACS server, err: %w", err) } opts = append(opts, option.WithGRPCConn(conn)) return opts, nil } // Close closes the ACS connection and client. func (acs *acsHelper) close(ctx context.Context) { galog.Debugf("Closing ACS connection") acs.connMu.Lock() defer acs.connMu.Unlock() // Reset connection state, this will trigger to create a new ACS connection // on next send/receive. acs.isConnectionSet.Store(false) acs.connection.Close() if ctx.Value(OverrideConnection) != nil { // Skip closing the client in override mode, it might not be set. return } if err := acs.client.Close(); err != nil { galog.V(2).Warnf("Failed to close ACS client: %v", err) } } // shouldReconnect returns true if the new acs connection needs to be created. func (acs *acsHelper) shouldReconnect() bool { return !acs.isConnectionSet.Load() || (acs.client == nil) || isNilInterface(acs.connection) } // connect creates new stream and connection to ACS and sets the client. // New connection is created if found nil or [isConnectionSet] is unset. 
// isConnectionSet can be unset to trigger new connection creation which is // generally required on send/receive error. func (acs *acsHelper) connect(ctx context.Context) error { acs.connMu.Lock() defer acs.connMu.Unlock() // Used of unit testing. Create connection makes a http call which must be avoided in unit tests. if ctx.Value(OverrideConnection) != nil { acs.connection = ctx.Value(OverrideConnection).(ConnectionInterface) // Override max retry attempts for unit tests to avoid spending too much // time waiting in retries. defaultRetrypolicy.MaxAttempts = 2 return nil } if !acs.shouldReconnect() { return nil } galog.Debugf("Creating new ACS connection") acs.channelID = channelID() opts, err := clientOptions(ctx) if err != nil { return fmt.Errorf("unable to get client options, err: %w", err) } acs.client, err = client.NewClient(ctx, false, opts...) if err != nil { return fmt.Errorf("unable to create new ACS client, err: %w", err) } acs.connection, err = client.NewConnection(ctx, acs.channelID, acs.client) if err != nil { return fmt.Errorf("unable to create new ACS connection, err: %w", err) } acs.isConnectionSet.Store(true) galog.Debugf("Created ACS connection") return nil } // sendStream sends a message to ACS stream. func (acs *acsHelper) sendStream(ctx context.Context, labels map[string]string, msg *anypb.Any) error { // [connection] will be nil if [CreateConnection] ever fails. In that case // simply retrying will result into [SIGSEGV]. Make sure connection is set // before sending message. if err := acs.connect(ctx); err != nil { return fmt.Errorf("unable to set connection for sending msg, err: %w", err) } err := acs.connection.SendMessage(&acpb.MessageBody{Labels: labels, Body: msg}) if err != nil { // Close connection if error occurs, this triggers to create a new ACS // connection and ensures the previous connection is closed. 
acs.close(ctx) } return err } // sendAgentMessage is a wrapper around client.SendAgentMessage to allow for // mocking in unit tests. var sendAgentMessage = func(ctx context.Context, channelID string, acsClient *agentcommunication.Client, msg *acpb.MessageBody) (*acpb.SendAgentMessageResponse, error) { return client.SendAgentMessage(ctx, channelID, acsClient, msg) } // sendStream sends a message and wait for response from ACS. func (acs *acsHelper) sendMessage(ctx context.Context, labels map[string]string, msg *anypb.Any) (*acpb.SendAgentMessageResponse, error) { if err := acs.connect(ctx); err != nil { return nil, fmt.Errorf("unable to connect for sending msg, err: %w", err) } resp, err := sendAgentMessage(ctx, acs.channelID, acs.client, &acpb.MessageBody{Labels: labels, Body: msg}) if err != nil { acs.close(ctx) } return resp, err } func (acs *acsHelper) receiveStream(ctx context.Context) (*acpb.MessageBody, error) { if err := acs.connect(ctx); err != nil { return nil, fmt.Errorf("unable to connect for watching new messages on channel, err: %w", err) } msg, err := acs.connection.Receive() if err != nil { acs.close(ctx) return nil, err } return msg, err } // SendMessage sends a message to ACS and waits for response. This is normally // used for sending metrics to ACS. 
func SendMessage(ctx context.Context, labels map[string]string, msg proto.Message) (*acpb.SendAgentMessageResponse, error) { if !cfg.Retrieve().Core.ACSClient { galog.V(2).Debugf("ACS client is disabled, ignoring send message request %v", msg) return nil, nil } anyMsg, err := anypb.New(msg) if err != nil { return nil, fmt.Errorf("unable to marshal message, err: %w", err) } fn := func() (*acpb.SendAgentMessageResponse, error) { resp, err := acs.sendMessage(ctx, labels, anyMsg) if err != nil { return nil, fmt.Errorf("unable to send message, err: %w", err) } galog.V(2).Debugf("Successfully sent message [%s] to ACS, resp: %s", anyMsg.String(), resp.String()) return resp, nil } return retry.RunWithResponse(ctx, defaultRetrypolicy, fn) } // Notify sends an event notification on ACS. func Notify(ctx context.Context, event *acmpb.PluginEventMessage) error { return Send(ctx, map[string]string{messageType: pluginEventMessageMsg}, event) } // Send sends a message on ACS. func Send(ctx context.Context, labels map[string]string, msg proto.Message) error { if !cfg.Retrieve().Core.ACSClient { galog.V(2).Debugf("ACS client is disabled, ignoring message %v", msg) return nil } anyMsg, err := anypb.New(msg) if err != nil { return fmt.Errorf("unable to marshal message, err: %w", err) } fn := func() error { if err := acs.sendStream(ctx, labels, anyMsg); err != nil { return fmt.Errorf("unable to send message on stream, err: %w", err) } galog.V(2).Debugf("Successfully sent message [%s] to ACS", anyMsg.String()) return nil } return retry.Run(ctx, defaultRetrypolicy, fn) } // Watch checks for a new message from ACS and returns. 
func Watch(ctx context.Context) (*acpb.MessageBody, error) { if !cfg.Retrieve().Core.ACSClient { galog.V(2).Debugf("ACS client is disabled, ignoring watch request") return nil, nil } fn := func() (*acpb.MessageBody, error) { msg, err := acs.receiveStream(ctx) if err != nil { return nil, fmt.Errorf("unable to listen on stream for new message, err: %w", err) } return msg, nil } return retry.RunWithResponse(ctx, watcherRetryPolicy(ctx), fn) } func watcherRetryPolicy(ctx context.Context) retry.Policy { if ctx.Value(OverrideConnection) != nil { // Override max retry attempts for unit tests to avoid spending too much // time waiting in retries. return retry.Policy{MaxAttempts: 2, Jitter: time.Millisecond, BackoffFactor: 1} } return watcherRetrypolicy } // isNilInterface returns true if the given interface is nil or of unexpected // type. Go’s interfaces are a pair of pointers. One points to the type and other // points to the underlying value. So, an interface is only considered nil when // both its type and value are nil. This helper method helps do a nil check // for [ConnectionInterface]. func isNilInterface(a any) bool { if a == nil { return true } v, ok := a.(*client.Connection) if !ok { galog.Debugf("Interface %v (%T) is not of type client.Connection", a, a) return true } return v == nil }