ecs-agent/acs/session/session.go (362 lines of code) (raw):
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
// Package session deals with appropriately reacting to all ACS messages as well
// as maintaining the connection to ACS.
package session
import (
	"context"
	"errors"
	"io"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs"
	rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
	"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
	"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
	"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
	"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
	"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"
	"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
	"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
	"github.com/aws/aws-sdk-go-v2/aws"
)
const (
// heartbeatTimeout is the maximum time to wait between heartbeats
// without disconnecting.
heartbeatTimeout = 1 * time.Minute
// heartbeatJitter is the jitter value passed to retry.AddJitter together
// with heartbeatTimeout whenever the heartbeat timer is armed or reset.
heartbeatJitter = 1 * time.Minute
// wsRWTimeout is the duration of read and write deadline for the
// websocket connection.
wsRWTimeout = 2*heartbeatTimeout + heartbeatJitter
// inactiveInstanceReconnectDelay is the default wait before reconnecting
// after a session ends with an inactive-instance error.
inactiveInstanceReconnectDelay = 1 * time.Hour
// The connectionBackoff* values are the arguments given to
// retry.NewExponentialBackoff for the reconnect backoff built in NewSession.
connectionBackoffMin = 250 * time.Millisecond
connectionBackoffMax = 2 * time.Minute
connectionBackoffJitter = 0.2
connectionBackoffMultiplier = 1.5
// inactiveInstanceExceptionPrefix is the text looked for in a session
// error's message to recognize an inactive container instance.
inactiveInstanceExceptionPrefix = "InactiveInstanceException"
// ACS protocol version spec:
// 1: default protocol version
// 2: ACS will proactively close the connection when heartbeat ACKs are missing
//
// The value is sent to ACS as the protocolVersion query parameter.
acsProtocolVersion = 2
)
// Session defines an interface for Agent's long-lived connection with ACS.
// The Session.Start() method can be used to start processing messages from ACS.
type Session interface {
// Start runs the session, reconnecting to ACS as needed, until the supplied
// context is done; it then returns the context's error.
Start(context.Context) error
// GetLastConnectedTime returns the timestamp of the most recent successful
// connection to ACS.
GetLastConnectedTime() time.Time
}
// session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS.
type session struct {
// containerInstanceARN and cluster identify this agent to ACS; they are sent
// as the containerInstanceArn and clusterArn query parameters of the ACS URL.
containerInstanceARN string
cluster string
// credentialsCache is handed to the websocket client factory for each new
// connection.
credentialsCache *aws.CredentialsCache
// ecsClient is used to discover the ACS poll endpoint before connecting.
ecsClient ecs.ECSClient
// inactiveInstanceCB is invoked when a session ends with an
// inactive-instance error.
inactiveInstanceCB func()
// agentVersion, agentHash and dockerVersion are reported to ACS as query
// parameters (dockerVersion only when it is non-empty).
agentVersion string
agentHash string
dockerVersion string
// The handlers and accessors from payloadMessageHandler through taskStopper
// are passed to the responders that startACSSession registers on each new
// websocket client.
payloadMessageHandler PayloadMessageHandler
credentialsManager rolecredentials.Manager
credentialsMetadataSetter CredentialsMetadataSetter
doctor *doctor.Doctor
eniHandler ENIHandler
manifestMessageIDAccessor ManifestMessageIDAccessor
taskComparer TaskComparer
sequenceNumberAccessor SequenceNumberAccessor
taskStopper TaskStopper
// resourceHandler is optional; the attach-resource responder is only
// registered when it is non-nil.
resourceHandler ResourceHandler
// backoff yields the reconnect delay after an unexpected disconnect.
backoff retry.Backoff
// sendCredentials is sent to ACS as the sendCredentials query parameter. It
// starts out true and is set to false after the first successful connection.
sendCredentials bool
// clientFactory creates the websocket client used for each connection attempt.
clientFactory wsclient.ClientFactory
// metricsFactory creates the metric entries recorded around connection
// attempts; it is also passed to the client factory and several responders.
metricsFactory metrics.EntryFactory
// minAgentConfig is passed through to the websocket client factory.
minAgentConfig *wsclient.WSClientMinAgentConfig
// addUpdateRequestHandlers is optional; when non-nil it is called with each
// new client unless dockerVersion is "containerd".
addUpdateRequestHandlers func(wsclient.ClientServer)
// heartbeatTimeout and heartbeatJitter set the initial duration of the
// heartbeat timer and of the backoff-reset timer.
heartbeatTimeout time.Duration
heartbeatJitter time.Duration
// disconnectTimeout and disconnectJitter are passed to the client's Connect call.
disconnectTimeout time.Duration
disconnectJitter time.Duration
// inactiveInstanceReconnectDelay is how long to wait before reconnecting
// after an inactive-instance error.
inactiveInstanceReconnectDelay time.Duration
// lastConnectedTime is the time of the most recent successful connection to
// ACS and firstACSConnectionTime the time of the first one. Both hold the
// zero time until a connection succeeds.
lastConnectedTime time.Time
firstACSConnectionTime time.Time
}
// NewSession assembles a Session from the supplied dependencies.
//
// Besides storing every argument, it equips the session with an exponential
// reconnect backoff built from the connectionBackoff* constants, the default
// heartbeat, disconnect and inactive-instance timings, and sendCredentials set
// to true. Both connection timestamps start as the zero time.
func NewSession(containerInstanceARN string,
	cluster string,
	ecsClient ecs.ECSClient,
	credentialsCache *aws.CredentialsCache,
	inactiveInstanceCB func(),
	clientFactory wsclient.ClientFactory,
	metricsFactory metrics.EntryFactory,
	agentVersion string,
	agentHash string,
	dockerVersion string,
	minAgentConfig *wsclient.WSClientMinAgentConfig,
	payloadMessageHandler PayloadMessageHandler,
	credentialsManager rolecredentials.Manager,
	credentialsMetadataSetter CredentialsMetadataSetter,
	doctor *doctor.Doctor,
	eniHandler ENIHandler,
	manifestMessageIDAccessor ManifestMessageIDAccessor,
	taskComparer TaskComparer,
	sequenceNumberAccessor SequenceNumberAccessor,
	taskStopper TaskStopper,
	resourceHandler ResourceHandler,
	addUpdateRequestHandlers func(wsclient.ClientServer),
) Session {
	// No connection has been made yet, so both timestamps are the zero time.
	var neverConnected time.Time

	newSession := &session{
		// Identity reported to ACS.
		containerInstanceARN: containerInstanceARN,
		cluster:              cluster,
		agentVersion:         agentVersion,
		agentHash:            agentHash,
		dockerVersion:        dockerVersion,

		// Connection plumbing.
		ecsClient:        ecsClient,
		credentialsCache: credentialsCache,
		clientFactory:    clientFactory,
		metricsFactory:   metricsFactory,
		minAgentConfig:   minAgentConfig,
		backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax,
			connectionBackoffJitter, connectionBackoffMultiplier),
		sendCredentials: true,

		// Message handling collaborators.
		inactiveInstanceCB:        inactiveInstanceCB,
		payloadMessageHandler:     payloadMessageHandler,
		credentialsManager:        credentialsManager,
		credentialsMetadataSetter: credentialsMetadataSetter,
		doctor:                    doctor,
		eniHandler:                eniHandler,
		manifestMessageIDAccessor: manifestMessageIDAccessor,
		taskComparer:              taskComparer,
		sequenceNumberAccessor:    sequenceNumberAccessor,
		taskStopper:               taskStopper,
		resourceHandler:           resourceHandler,
		addUpdateRequestHandlers:  addUpdateRequestHandlers,

		// Default timings.
		heartbeatTimeout:               heartbeatTimeout,
		heartbeatJitter:                heartbeatJitter,
		disconnectTimeout:              wsclient.DisconnectTimeout,
		disconnectJitter:               wsclient.DisconnectJitterMax,
		inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay,

		// Connection bookkeeping.
		lastConnectedTime:      neverConnected,
		firstACSConnectionTime: neverConnected,
	}
	return newSession
}
// Start runs the ACS session loop. It connects to ACS and, every time the
// connection ends, reconnects either immediately or after the delay chosen by
// reconnectDelay. It keeps doing so until the context is done.
//
// If the context is closed, Start() returns the error reported by the
// context.
func (s *session) Start(ctx context.Context) error {
	// connectToACS carries the intent to (re)connect to ACS and is consumed by
	// the select loop below. At most one signal is ever outstanding: one is
	// sent up front, and each received signal leads to at most one new send.
	// A buffer of one therefore guarantees that the sending goroutine always
	// finishes, even when the loop exits through ctx.Done() without receiving
	// the pending signal. With an unbuffered channel that goroutine would stay
	// blocked forever.
	connectToACS := make(chan struct{}, 1)
	// The below is required to trigger the first connection to ACS.
	sendEmptyMessageOnChannel(connectToACS)

	// Loop continuously until context is closed/canceled.
	for {
		select {
		case <-connectToACS:
			logger.Debug("Received connect to ACS message. Attempting connect to ACS",
				logger.Fields{
					"containerInstanceARN": s.containerInstanceARN,
				})

			// Run one session with ACS; this blocks until the session ends.
			acsError := s.startSessionOnce(ctx)

			// Decide how to reconnect based on why the session ended.
			reconnectDelay, ok := s.reconnectDelay(acsError)
			if !ok {
				// No need to delay reconnect - reconnect immediately.
				logger.Info("Reconnecting to ACS immediately without waiting",
					logger.Fields{
						"containerInstanceARN": s.containerInstanceARN,
					})
				sendEmptyMessageOnChannel(connectToACS)
				continue
			}

			logger.Info("Waiting before reconnecting to ACS", logger.Fields{
				"reconnectDelay":       reconnectDelay.String(),
				"containerInstanceARN": s.containerInstanceARN,
			})
			if !waitForDuration(ctx, reconnectDelay) {
				// Wait was interrupted. We expect the session to close as
				// canceling the session context is the only way to end up
				// here. No reconnect is scheduled; the next loop iteration
				// observes ctx.Done().
				logger.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close",
					logger.Fields{
						"containerInstanceARN": s.containerInstanceARN,
					})
				continue
			}

			// The full delay elapsed without the context being canceled, so
			// schedule the reconnect.
			logger.Info("Done waiting; reconnecting to ACS",
				logger.Fields{
					"containerInstanceARN": s.containerInstanceARN,
				})
			sendEmptyMessageOnChannel(connectToACS)

		case <-ctx.Done():
			logger.Info("ACS session ended (context closed)", logger.Fields{
				field.Reason:           ctx.Err(),
				"containerInstanceARN": s.containerInstanceARN,
			})
			return ctx.Err()
		}
	}
}
// startSessionOnce performs a single connect-and-serve cycle against ACS.
//
// It discovers the poll endpoint, creates a websocket client for it, connects,
// records the connection metrics and timestamps, and then hands over to
// startACSSession, whose result it returns. Endpoint discovery and connection
// failures are logged and returned as-is.
func (s *session) startSessionOnce(ctx context.Context) error {
	endpoint, discoverErr := s.ecsClient.DiscoverPollEndpoint(s.containerInstanceARN)
	if discoverErr != nil {
		logger.Error("ACS: Unable to discover poll endpoint", logger.Fields{
			"containerInstanceARN": s.containerInstanceARN,
			field.Error:            discoverErr,
		})
		return discoverErr
	}

	wsURL := s.acsURL(endpoint)
	client := s.clientFactory.New(wsURL, s.credentialsCache, wsRWTimeout, s.minAgentConfig, s.metricsFactory)
	defer client.Close()

	// Connect right after creating the client so that every request handler
	// registered later is associated with a client that has a valid connection.
	connectStartedAt := time.Now()
	disconnectTimer, connectErr := client.Connect(
		metrics.ACSDisconnectTimeoutMetricName,
		s.disconnectTimeout,
		s.disconnectJitter,
	)
	// Metric that tells whether the ACS connection attempt succeeded or not.
	s.metricsFactory.New(metrics.ACSSessionFailureCallName).Done(connectErr)
	if connectErr != nil {
		logger.Error("Failed to connect to ACS", logger.Fields{
			"containerInstanceARN": s.containerInstanceARN,
			field.Error:            connectErr,
		})
		return connectErr
	}

	durationEntry := s.metricsFactory.New(metrics.ACSSessionCallDurationName)
	durationEntry.WithGauge(time.Since(connectStartedAt).Milliseconds()).Done(nil)
	defer disconnectTimer.Stop()

	// Remember when the very first connection happened...
	if s.firstACSConnectionTime.IsZero() {
		s.firstACSConnectionTime = time.Now()
	}
	// ...and always record the timestamp of the latest one.
	s.lastConnectedTime = time.Now()
	logger.Info("Connected to ACS endpoint",
		logger.Fields{
			"containerInstanceARN": s.containerInstanceARN,
			"lastConnectedTime":    s.lastConnectedTime,
		})

	// The connection succeeded. From now on rely on ACS to send credentials at
	// its own cadence; do not force it to resend them on subsequent reconnects.
	s.sendCredentials = false

	return s.startACSSession(ctx, client)
}
// startACSSession serves one ACS session on an already connected client.
//
// It registers the request handlers for the message kinds expected from ACS,
// arms the heartbeat timer and the backoff-reset timer, and then blocks in
// client.Serve. It returns whatever Serve returns, i.e. on server
// disconnection or when the context is canceled.
func (s *session) startACSSession(ctx context.Context, client wsclient.ClientServer) error {
	sessionCtx, cancelSession := context.WithCancel(ctx)
	defer func() {
		logger.Info("Current ACS session is done and a new one might be created.",
			logger.Fields{
				"containerInstanceARN": s.containerInstanceARN,
			})
		cancelSession()
	}()

	// Every responder replies to ACS through the client's MakeRequest.
	sendToACS := func(response interface{}) error {
		return client.MakeRequest(response)
	}

	for _, responder := range []wsclient.RequestResponder{
		NewPayloadResponder(s.payloadMessageHandler, sendToACS),
		NewRefreshCredentialsResponder(s.credentialsManager, s.credentialsMetadataSetter, s.metricsFactory,
			sendToACS),
		NewAttachTaskENIResponder(s.eniHandler, sendToACS),
		NewAttachInstanceENIResponder(s.eniHandler, sendToACS),
		NewHeartbeatResponder(s.doctor, sendToACS),
		NewTaskManifestResponder(s.taskComparer, s.sequenceNumberAccessor, s.manifestMessageIDAccessor,
			s.metricsFactory, sendToACS),
		NewTaskStopVerificationACKResponder(s.taskStopper, s.manifestMessageIDAccessor, s.metricsFactory),
	} {
		client.AddRequestHandler(responder.HandlerFunc())
	}

	// The attach-resource responder is only registered when a handler was supplied.
	if s.resourceHandler != nil {
		attachResource := NewAttachResourceResponder(s.resourceHandler, s.metricsFactory, sendToACS)
		client.AddRequestHandler(attachResource.HandlerFunc())
	}

	// Update handlers are skipped for "containerd" and when no hook was supplied.
	if s.dockerVersion != "containerd" {
		if s.addUpdateRequestHandlers != nil {
			s.addUpdateRequestHandlers(client)
		}
	}

	// The heartbeat timer closes the connection on inactivity; any message
	// from the server resets it.
	heartbeatTimer := newHeartbeatTimer(client, s.heartbeatTimeout, s.heartbeatJitter)
	client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client))
	defer heartbeatTimer.Stop()

	// If we connect without error and stay connected for about a heartbeat
	// timeout, reset the backoff. This keeps infrequent disconnect errors from
	// inflating the reconnect delay as significantly.
	resetBackoff := func() { s.backoff.Reset() }
	backoffResetTimer := time.AfterFunc(retry.AddJitter(s.heartbeatTimeout, s.heartbeatJitter), resetBackoff)
	defer backoffResetTimer.Stop()

	return client.Serve(sessionCtx)
}
// reconnectDelay decides how to reconnect after a session ended with acsError.
//
// The boolean result tells the caller whether to wait at all: when it is true
// the returned duration is the time to wait before reconnecting, when it is
// false the caller should reconnect right away.
//
//   - Inactive-instance error: inactiveInstanceCB is invoked and the
//     inactive-instance reconnect delay is returned.
//   - Error accepted by shouldReconnectWithoutBackoff: the backoff is reset
//     and no wait is requested.
//   - Anything else: the next backoff duration is returned.
func (s *session) reconnectDelay(acsError error) (time.Duration, bool) {
	switch {
	case isInactiveInstanceError(acsError):
		logger.Info("Container instance is deregistered",
			logger.Fields{
				"containerInstanceARN": s.containerInstanceARN,
			})
		s.inactiveInstanceCB()
		return s.inactiveInstanceReconnectDelay, true

	case shouldReconnectWithoutBackoff(acsError):
		// ACS closed the connection for a valid reason (for example a periodic
		// disconnect), so there is nothing to back off from.
		logger.Info("ACS WebSocket connection closed for a valid reason",
			logger.Fields{
				"containerInstanceARN": s.containerInstanceARN,
			})
		s.backoff.Reset()
		return 0, false

	default:
		// Unexpected disconnect: wait for the next backoff interval.
		logger.Warn("ACS WebSocket connection closed",
			logger.Fields{
				"containerInstanceARN": s.containerInstanceARN,
				field.Error:            acsError,
			})
		return s.backoff.Duration(), true
	}
}
// acsURL returns the websocket url for ACS given the endpoint.
//
// The result is the endpoint followed by "/ws" (a "/" is only inserted when
// the endpoint does not already end with one) and a query string carrying the
// cluster and container instance ARNs, the agent hash and version, a sequence
// number of "1", the ACS protocol version, the formatted docker version (only
// when one is set) and the sendCredentials flag.
//
// An empty endpoint does not panic; it is treated like any other endpoint
// without a trailing slash.
func (s *session) acsURL(endpoint string) string {
	wsURL := endpoint
	// HasSuffix is safe on an empty string, unlike indexing its last byte.
	if !strings.HasSuffix(endpoint, "/") {
		wsURL += "/"
	}
	wsURL += "ws"

	query := url.Values{}
	query.Set("clusterArn", s.cluster)
	query.Set("containerInstanceArn", s.containerInstanceARN)
	query.Set("agentHash", s.agentHash)
	query.Set("agentVersion", s.agentVersion)
	query.Set("seqNum", "1")
	query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
	if s.dockerVersion != "" {
		query.Set("dockerVersion", formatDockerVersion(s.dockerVersion))
	}
	// Below indicates if ACS should send credentials for all tasks upon establishing the connection.
	query.Set("sendCredentials", strconv.FormatBool(s.sendCredentials))

	return wsURL + "?" + query.Encode()
}
// responseToACSSender wraps responseSender in a wsclient.RespondFunc for use by
// a responder that answers websocket request messages from ACS. Each call of
// the returned function:
//  1. logs, at debug level, the response together with responderName
//  2. forwards the response to responseSender and returns its result
func responseToACSSender(responderName string, responseSender wsclient.RespondFunc) wsclient.RespondFunc {
	loggingSender := func(response interface{}) error {
		logFields := logger.Fields{
			"Name":     responderName,
			"Response": response,
		}
		logger.Debug("Sending response to ACS", logFields)
		return responseSender(response)
	}
	return loggingSender
}
// newHeartbeatTimer creates a new timer that closes the ACS client connection
// when it fires. The timer is armed with the duration obtained from
// retry.AddJitter(timeout, jitter).
func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer {
	// closeOnInactivity runs when the timer fires without having been reset.
	closeOnInactivity := func() {
		logger.Warn("ACS Connection hasn't had any activity for too long; closing connection")
		closeErr := client.Close()
		if closeErr != nil {
			logger.Warn("Error disconnecting from ACS", logger.Fields{
				field.Error: closeErr,
			})
		}
		logger.Info("Disconnected from ACS")
	}
	return time.AfterFunc(retry.AddJitter(timeout, jitter), closeOnInactivity)
}
// anyMessageHandler returns the handler invoked for any server message. A
// message means the connection is active, so the handler pushes the client's
// read deadline out by wsRWTimeout and resets the heartbeat timer so that the
// inactivity disconnect does not occur.
//
// NOTE(review): the timer is reset using the package-level heartbeatTimeout
// and heartbeatJitter constants, while startACSSession arms it from the
// session's heartbeatTimeout/heartbeatJitter fields - confirm this is intended.
func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) {
	return func(_ interface{}) {
		logger.Debug("ACS activity occurred")

		// There was activity on the channel, so extend the read deadline.
		nextReadDeadline := time.Now().Add(wsRWTimeout)
		if deadlineErr := client.SetReadDeadline(nextReadDeadline); deadlineErr != nil {
			logger.Warn("Unable to extend read deadline for ACS connection", logger.Fields{
				field.Error: deadlineErr,
			})
		}

		// Restart the heartbeat countdown.
		nextHeartbeatTimeout := retry.AddJitter(heartbeatTimeout, heartbeatJitter)
		timer.Reset(nextHeartbeatTimeout)
	}
}
// waitForDuration blocks until either the specified duration has elapsed or
// ctx is done, whichever happens first.
//
// It returns true only if the full duration elapsed. It returns false when ctx
// was already done on entry or became done during the wait, whether through
// cancellation or through ctx's own deadline expiring. A non-positive duration
// on a live context counts as an immediately completed wait.
func waitForDuration(ctx context.Context, duration time.Duration) bool {
	// A context that is already done can never yield a completed wait; check
	// it first so the outcome does not depend on select's random choice.
	if ctx.Err() != nil {
		return false
	}

	// A dedicated timer (rather than a derived deadline context) lets us tell
	// "our wait elapsed" apart from "the parent's deadline expired", which
	// both surface as context.DeadlineExceeded on a derived context.
	timer := time.NewTimer(duration)
	defer timer.Stop()

	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
// sendEmptyMessageOnChannel delivers one empty message on the specified
// channel from a freshly started goroutine, so the caller itself never blocks
// on the send.
func sendEmptyMessageOnChannel(channel chan<- struct{}) {
	var emptyMessage struct{}
	go func(out chan<- struct{}) {
		out <- emptyMessage
	}(channel)
}
// shouldReconnectWithoutBackoff reports whether a session that ended with
// acsError may reconnect without backing off. That is the case when there was
// no error at all or when the error is io.EOF, including an io.EOF that has
// been wrapped by another error.
func shouldReconnectWithoutBackoff(acsError error) bool {
	// errors.Is also matches io.EOF inside a wrapped error chain, which a
	// plain == comparison would miss.
	return acsError == nil || errors.Is(acsError, io.EOF)
}
// isInactiveInstanceError reports whether acsError is non-nil and its message
// contains inactiveInstanceExceptionPrefix.
func isInactiveInstanceError(acsError error) bool {
	if acsError == nil {
		return false
	}
	message := acsError.Error()
	return strings.Contains(message, inactiveInstanceExceptionPrefix)
}
// formatDockerVersion renders the docker version value for the ACS URL: the
// literal "containerd" is returned unchanged, every other value is prefixed
// with "DockerVersion: ".
func formatDockerVersion(dockerVersionValue string) string {
	if dockerVersionValue == "containerd" {
		return dockerVersionValue
	}
	return "DockerVersion: " + dockerVersionValue
}
// GetLastConnectedTime returns the timestamp that the last connection was established to ACS.
// It is the zero time until a connection to ACS has succeeded.
func (s *session) GetLastConnectedTime() time.Time {
return s.lastConnectedTime
}
// GetFirstACSConnectionTime returns the timestamp at which the first connection
// to ACS was established. It is the zero time until a connection to ACS has
// succeeded.
func (s *session) GetFirstACSConnectionTime() time.Time {
return s.firstACSConnectionTime
}