client.go (502 lines of code) (raw):

// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package client is an AgentCommunication client library. package client import ( "context" "errors" "fmt" "io" "log" "os" "strconv" "strings" "sync" "time" "github.com/GoogleCloudPlatform/agentcommunication_client/gapic" "google.golang.org/api/option" "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" acpb "github.com/GoogleCloudPlatform/agentcommunication_client/gapic/agentcommunicationpb" ) func init() { logger = log.New(os.Stderr, "", log.Ldate|log.Ltime|log.Lshortfile) } const ( metadataMessageRateLimit = "agent-communication-message-rate-limit" metadataBandwidthLimit = "agent-communication-bandwidth-limit" ) var ( // DebugLogging enables debug logging. DebugLogging = false // ErrConnectionClosed is an error indicating that the connection was closed by the caller. ErrConnectionClosed = errors.New("connection closed") // ErrMessageTimeout is an error indicating message send timed out. ErrMessageTimeout = errors.New("timed out waiting for response") // ErrResourceExhausted is an error indicating that the server responded to the send with // ResourceExhausted. ErrResourceExhausted = errors.New("resource exhausted") // ErrGettingInstanceToken is an error indicating that the instance token could not be retrieved. ErrGettingInstanceToken = errors.New("error getting instance token") // defaultOpts are the default options for used for creating ACS clients. defaultOpts = []option.ClientOption{ option.WithoutAuthentication(), // Do not use oauth. option.WithGRPCDialOption(grpc.WithTransportCredentials(credentials.NewTLS(nil))), // Because we disabled Auth we need to specifically enable TLS. option.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: 60 * time.Second, Timeout: 10 * time.Second})), } logger *log.Logger ) func loggerPrintf(format string, v ...any) { if DebugLogging { logger.Output(2, fmt.Sprintf(format, v...)) } } func getEndpoint(regional bool) (string, error) { zone, err := getZone() if err != nil { return "", err } location := zone if regional { index := strings.LastIndex(location, "-") if index == -1 { return "", fmt.Errorf("zone %q is not a valid zone", zone) } location = location[:index] } return fmt.Sprintf("%s-agentcommunication.googleapis.com.:443", location), nil } // NewClient creates a new agent communication grpc client. // Caller must close the returned client when it is done being used to clean up its underlying // connections. func NewClient(ctx context.Context, regional bool, opts ...option.ClientOption) (*agentcommunication.Client, error) { endpoint, err := getEndpoint(regional) if err != nil { return nil, err } optsWithEndpoint := append(defaultOpts, option.WithEndpoint(endpoint)) return agentcommunication.NewClient(ctx, append(optsWithEndpoint, opts...)...) } // SendAgentMessage sends a message to the client. This is equivalent to sending a message via // StreamAgentMessages with a single message and waiting for the response. func SendAgentMessage(ctx context.Context, channelID string, client *agentcommunication.Client, msg *acpb.MessageBody) (*acpb.SendAgentMessageResponse, error) { loggerPrintf("SendAgentMessage") resourceID, err := getResourceID() if err != nil { return nil, err } token, err := getIdentityToken() if err != nil { return nil, fmt.Errorf("%w: %v", ErrGettingInstanceToken, err) } ctx = metadata.NewOutgoingContext(ctx, metadata.New(map[string]string{ "authentication": "Bearer " + token, "agent-communication-resource-id": resourceID, "agent-communication-channel-id": channelID, })) loggerPrintf("Using ResourceID %q", resourceID) loggerPrintf("Using ChannelID %q", channelID) return client.SendAgentMessage(ctx, &acpb.SendAgentMessageRequest{ ChannelId: channelID, ResourceId: resourceID, MessageBody: msg, }) } // Connection is an AgentCommunication connection. type Connection struct { client *agentcommunication.Client // Indicates that the client is caller managed and should not be closed. callerManagedClient bool // Indicates that the entire connection is closed and will not reopen. closed chan struct{} closeErr error closeErrMx sync.RWMutex // Indicates that the underlying stream is ready to send. streamReady chan struct{} sends chan *acpb.StreamAgentMessagesRequest resourceID string channelID string messages chan *acpb.MessageBody responseSubs map[string]chan *status.Status responseMx sync.Mutex regional bool timeToWaitForResp time.Duration limitsMx sync.Mutex messageRateLimit int messageBandwidthLimit int } // MessageRateLimit returns the message limit (in messages/minute) for the connection. func (c *Connection) MessageRateLimit() int { c.limitsMx.Lock() defer c.limitsMx.Unlock() return c.messageRateLimit } // MessageBandwidthLimit returns the message bandwidth limit (in bytes/minute) for the connection. func (c *Connection) MessageBandwidthLimit() int { c.limitsMx.Lock() defer c.limitsMx.Unlock() return c.messageBandwidthLimit } // Close the connection. func (c *Connection) Close() { c.close(ErrConnectionClosed) } func (c *Connection) setCloseErr(err error) { c.closeErrMx.Lock() defer c.closeErrMx.Unlock() c.closeErr = err } func (c *Connection) getCloseErr() error { c.closeErrMx.RLock() defer c.closeErrMx.RUnlock() return c.closeErr } func (c *Connection) close(err error) { loggerPrintf("closing connection with err: %v", err) st, _ := status.FromError(err) loggerPrintf("closing connection with status: %+v", st) select { case <-c.closed: return default: close(c.closed) c.setCloseErr(err) if !c.callerManagedClient { c.client.Close() } } } func (c *Connection) waitForResponse(key string, channel chan *status.Status) error { defer func() { c.responseMx.Lock() delete(c.responseSubs, key) c.responseMx.Unlock() }() timer := time.NewTimer(c.timeToWaitForResp) defer timer.Stop() select { case st := <-channel: if st != nil { switch st.Code() { case codes.OK: case codes.ResourceExhausted: return fmt.Errorf("%w: %s", ErrResourceExhausted, st.Message()) default: return fmt.Errorf("unexpected status: %+v", st) } } case <-timer.C: return fmt.Errorf("%w: timed out waiting for response, MessageID: %q", ErrMessageTimeout, key) case <-c.closed: return fmt.Errorf("connection closed with err: %w", c.getCloseErr()) } return nil } func (c *Connection) sendWithResp(req *acpb.StreamAgentMessagesRequest, channel chan *status.Status) error { loggerPrintf("Sending message %+v", req) select { case <-c.closed: return fmt.Errorf("connection closed with err: %w", c.getCloseErr()) case c.sends <- req: } return c.waitForResponse(req.GetMessageId(), channel) } func (c *Connection) sendMessage(msg *acpb.MessageBody) error { req := &acpb.StreamAgentMessagesRequest{ MessageId: uuid.New().String(), Type: &acpb.StreamAgentMessagesRequest_MessageBody{MessageBody: msg}, } channel := make(chan *status.Status) c.responseMx.Lock() c.responseSubs[req.GetMessageId()] = channel c.responseMx.Unlock() select { case <-c.closed: return fmt.Errorf("connection closed with err: %w", c.getCloseErr()) case c.streamReady <- struct{}{}: // Only sends if the stream is ready to send. } return c.sendWithResp(req, channel) } // SendMessage sends a message to the client. Will automatically retry on message timeout (temporary // disconnects) and in the case of ResourceExhausted with a backoff. Because retries are limited // the returned error can in some cases be one of ErrMessageTimeout or ErrResourceExhausted, in // which case send should be retried by the caller. func (c *Connection) SendMessage(msg *acpb.MessageBody) error { var err error // Retry 4 times. for i := 1; i <= 5; i++ { err := c.sendMessage(msg) if errors.Is(err, ErrResourceExhausted) { // Start with 250ms sleep, then simply multiply by iteration. time.Sleep(time.Duration(i*250) * time.Millisecond) continue } else if errors.Is(err, ErrMessageTimeout) { continue } return err } return err } // Receive messages, Receive should be called continuously for the life of the stream connection, // any delay in Receive when there are queued messages will cause the server to disconnect the // stream. This means handling the MessageBody from Receive should not be blocking, offload message // handling to another goroutine and immediately call Receive again. func (c *Connection) Receive() (*acpb.MessageBody, error) { select { case msg := <-c.messages: return msg, nil case <-c.closed: return nil, fmt.Errorf("connection closed with err: %w", c.getCloseErr()) } } func (c *Connection) streamSend(req *acpb.StreamAgentMessagesRequest, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error { select { case <-streamClosed: return errors.New("stream closed") case streamSendLock <- struct{}{}: defer func() { <-streamSendLock }() } if err := stream.Send(req); err != nil { if err != io.EOF && !errors.Is(err, io.EOF) { // Something is very broken, just close the stream here. loggerPrintf("Unexpected send error, closing connection: %v", err) c.close(err) return err } // EOF error means the stream is closed, this should be picked up by recv, but that could be // blocked, close our sends for now and just allow the caller handle it, SendMessage will wait // for response which will never come and auto retry. acknowledgeMessage will fail and prevent // the message from being passed on to message handlers, allowing recv to handle the stream // close error. loggerPrintf("Error sending message, stream closed.") select { case <-streamClosed: default: close(streamClosed) } return ErrConnectionClosed } return nil } func (c *Connection) send(streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) { defer func() { // Lock the stream sends so we can close the stream. streamSendLock <- struct{}{} stream.CloseSend() }() for { select { case req := <-c.sends: if err := c.streamSend(req, streamClosed, streamSendLock, stream); err != nil { return } case <-c.closed: return case <-streamClosed: return } } } func (c *Connection) acknowledgeMessage(messageID string, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) error { ackReq := &acpb.StreamAgentMessagesRequest{ MessageId: messageID, Type: &acpb.StreamAgentMessagesRequest_MessageResponse{}, } select { case <-c.closed: return fmt.Errorf("connection closed with err: %w", c.closeErr) default: return c.streamSend(ackReq, streamClosed, streamSendLock, stream) } } func (c *Connection) readHeaders(stream acpb.AgentCommunication_StreamAgentMessagesClient) { md, err := stream.Header() if err != nil { loggerPrintf("Error getting stream header: %v", err) return } c.limitsMx.Lock() defer c.limitsMx.Unlock() value := md.Get(metadataMessageRateLimit) if len(value) > 0 { rateLimit, err := strconv.Atoi(value[0]) if err != nil { loggerPrintf("Error parsing message rate limit: %v", err) } else { c.messageRateLimit = rateLimit loggerPrintf("Message rate limit: %d", c.messageRateLimit) } } else { loggerPrintf("No message rate limit") } value = md.Get(metadataBandwidthLimit) if len(value) > 0 { bandwidthLimit, err := strconv.Atoi(value[0]) if err != nil { loggerPrintf("Error parsing message bandwidth limit: %v", err) } else { c.messageBandwidthLimit = bandwidthLimit loggerPrintf("Message bandwidth limit: %d", c.messageBandwidthLimit) } } else { loggerPrintf("No message bandwidth limit") } } // recv keeps receiving and acknowledging new messages. func (c *Connection) recv(ctx context.Context, streamClosed, streamSendLock chan struct{}, stream acpb.AgentCommunication_StreamAgentMessagesClient) { loggerPrintf("Receiving messages") for { resp, err := stream.Recv() if err != nil { select { case <-streamClosed: default: // Causes the send goroutine to exit, c.sends will block. close(streamClosed) } select { case <-c.closed: // Connection is closed, return now. loggerPrintf("Connection closed, recv returning") return default: } st, ok := status.FromError(err) if ok && st.Code() == codes.ResourceExhausted { loggerPrintf("Connection closed due to resource exhausted: %v", err) } else if ok && st.Code() == codes.Unavailable { loggerPrintf("Stream returned Unavailable, will reconnect: %v", err) } else if err != io.EOF && !errors.Is(err, io.EOF) && (ok && st.Code() != codes.Canceled) && (ok && st.Code() != codes.DeadlineExceeded) { // EOF is a normal stream close, Canceled will be set by the server when stream timeout is // reached, DeadlineExceeded would be because of the client side deadline we set. loggerPrintf("Unexpected error, closing connection: %v", err) c.close(err) return } // A new stream is created if: // 1. Resource exhausted is returned but we have not exceeded the max number of retries. // 2. Unavailable is returned but we have not exceeded the max number of retries. // 3. A known "normal disconnect" error is returned. loggerPrintf("Creating new stream") if err := c.createStream(ctx); err != nil { loggerPrintf("Error creating new stream: %v", err) c.close(err) } // Always return here, createStream launches a new recv goroutine. return } switch resp.GetType().(type) { case *acpb.StreamAgentMessagesResponse_MessageBody: // Acknowledge message first, if this ack fails dont forward the message on to the handling // logic since that indicates a stream disconnect. if err := c.acknowledgeMessage(resp.GetMessageId(), streamClosed, streamSendLock, stream); err != nil { loggerPrintf("Error acknowledging message %q: %v", resp.GetMessageId(), err) continue } c.messages <- resp.GetMessageBody() case *acpb.StreamAgentMessagesResponse_MessageResponse: st := resp.GetMessageResponse().GetStatus() c.responseMx.Lock() for key, sub := range c.responseSubs { if key != resp.GetMessageId() { continue } select { case sub <- status.FromProto(st): default: } } c.responseMx.Unlock() } } } func createStreamLoop(ctx context.Context, client *agentcommunication.Client, resourceID string, channelID string) (acpb.AgentCommunication_StreamAgentMessagesClient, error) { resourceExhaustedRetries := 0 unavailableRetries := 0 for { stream, err := client.StreamAgentMessages(ctx) if err != nil { return nil, fmt.Errorf("error creating stream: %v", err) } // RegisterConnection is a special message that must be sent before any other messages. req := &acpb.StreamAgentMessagesRequest{ MessageId: uuid.New().String(), Type: &acpb.StreamAgentMessagesRequest_RegisterConnection{ RegisterConnection: &acpb.RegisterConnection{ResourceId: resourceID, ChannelId: channelID}}} if err := stream.Send(req); err != nil { return nil, fmt.Errorf("error sending register connection: %v", err) } // We expect the first message to be a MessageResponse. resp, err := stream.Recv() if err == nil { switch resp.GetType().(type) { case *acpb.StreamAgentMessagesResponse_MessageResponse: if resp.GetMessageResponse().GetStatus().GetCode() != int32(codes.OK) { return nil, fmt.Errorf("unexpected register response: %+v", resp.GetMessageResponse().GetStatus()) } } // Stream is connected. return stream, nil } st, ok := status.FromError(err) if ok && st.Code() == codes.ResourceExhausted { loggerPrintf("Resource exhausted, sleeping before reconnect: %v", err) if resourceExhaustedRetries > 20 { loggerPrintf("Stream returned ResourceExhausted, exceeded max number of reconnects, closing connection: %v", err) } sleep := time.Duration(resourceExhaustedRetries+1) * time.Second if resourceExhaustedRetries > 9 { sleep = 10 * time.Second } time.Sleep(sleep) resourceExhaustedRetries++ continue } else if ok && st.Code() == codes.Unavailable { // Retry max 5 times (2s total). if unavailableRetries <= 5 { loggerPrintf("Stream returned Unavailable, will reconnect: %v", err) // Sleep for 200ms * num of unavailableRetries, first retry is immediate. time.Sleep(time.Duration(unavailableRetries*200) * time.Millisecond) unavailableRetries++ continue } loggerPrintf("Stream returned Unavailable, exceeded max number of reconnects, closing connection: %v", err) } return nil, err } } func (c *Connection) createStream(ctx context.Context) error { loggerPrintf("Creating stream.") token, err := getIdentityToken() if err != nil { return fmt.Errorf("%w: %v", ErrGettingInstanceToken, err) } ctx = metadata.NewOutgoingContext(ctx, metadata.New(map[string]string{ "authentication": "Bearer " + token, "agent-communication-resource-id": c.resourceID, "agent-communication-channel-id": c.channelID, })) loggerPrintf("Using ResourceID %q", c.resourceID) loggerPrintf("Using ChannelID %q", c.channelID) // Set a timeout for the stream, this is well above service side timeout. cnclCtx, cancel := context.WithTimeout(ctx, 60*time.Minute) stream, err := createStreamLoop(cnclCtx, c.client, c.resourceID, c.channelID) if err != nil { cancel() c.close(err) return err } // Reading headers is best effort, if we fail to read headers we will just log the error and // continue. c.readHeaders(stream) // Used to signal that the stream is closed. streamClosed := make(chan struct{}) // This ensures that only one send is happening at a time. streamSendLock := make(chan struct{}, 1) go c.recv(ctx, streamClosed, streamSendLock, stream) go c.send(streamClosed, streamSendLock, stream) go func() { defer cancel() for { select { // Indicates that the stream is setup and is ready to send, this is used by sendMessage to // block sends during reconnect. case <-c.streamReady: case <-streamClosed: return } } }() loggerPrintf("Stream established.") return nil } // NewConnection creates a new streaming connection. // Caller is responsible for calling Close() on the connection when done, certain errors will cause // the connection to be closed automatically. The passed in client will not be closed and can be // reused. func NewConnection(ctx context.Context, channelID string, client *agentcommunication.Client) (*Connection, error) { conn := &Connection{ channelID: channelID, closed: make(chan struct{}), messages: make(chan *acpb.MessageBody), responseSubs: make(map[string]chan *status.Status), streamReady: make(chan struct{}), sends: make(chan *acpb.StreamAgentMessagesRequest), timeToWaitForResp: 2 * time.Second, client: client, callerManagedClient: true, } var err error conn.resourceID, err = getResourceID() if err != nil { return nil, err } if err := conn.createStream(ctx); err != nil { conn.close(err) return nil, err } return conn, nil } // CreateConnection creates a new connection. // DEPRECATED: Use NewConnection instead. func CreateConnection(ctx context.Context, channelID string, regional bool, opts ...option.ClientOption) (*Connection, error) { conn := &Connection{ channelID: channelID, closed: make(chan struct{}), messages: make(chan *acpb.MessageBody), responseSubs: make(map[string]chan *status.Status), streamReady: make(chan struct{}), sends: make(chan *acpb.StreamAgentMessagesRequest), timeToWaitForResp: 2 * time.Second, } var err error conn.resourceID, err = getResourceID() if err != nil { return nil, err } conn.client, err = NewClient(ctx, regional, opts...) if err != nil { return nil, err } if err := conn.createStream(ctx); err != nil { conn.close(err) return nil, err } return conn, nil }