// ecs-agent/api/ecs/client/ecs_client.go (804 lines of code) (raw):

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

// Package ecsclient implements the agent-side client for the ECS control-plane
// APIs (RegisterContainerInstance, SubmitTaskStateChange, DiscoverPollEndpoint, etc.).
package ecsclient

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status"
	"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs"
	apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors"
	"github.com/aws/amazon-ecs-agent/ecs-agent/async"
	"github.com/aws/amazon-ecs-agent/ecs-agent/config"
	"github.com/aws/amazon-ecs-agent/ecs-agent/ec2"
	"github.com/aws/amazon-ecs-agent/ecs-agent/httpclient"
	"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
	"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
	"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
	"github.com/aws/amazon-ecs-agent/ecs-agent/utils"
	"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"
	"github.com/aws/aws-sdk-go-v2/aws"
	awsconfig "github.com/aws/aws-sdk-go-v2/config"
	ecsservice "github.com/aws/aws-sdk-go-v2/service/ecs"
	"github.com/aws/aws-sdk-go-v2/service/ecs/types"
	"github.com/aws/smithy-go"
	"github.com/docker/docker/pkg/meminfo"
)

const (
	// Maximum field lengths accepted by the ECS backend; longer values are
	// trimmed before submission (see trimString / trimStringPtr below).
	ecsMaxImageDigestLength     = 255
	ecsMaxContainerReasonLength = 1024
	ecsMaxTaskReasonLength      = 1024
	ecsMaxRuntimeIDLength       = 255

	// How long a DiscoverPollEndpoint response stays fresh in the TTL cache.
	defaultPollEndpointCacheTTL = 12 * time.Hour

	// Well-known container instance attribute names.
	azAttrName       = "ecs.availability-zone"
	cpuArchAttrName  = "ecs.cpu-architecture"
	osTypeAttrName   = "ecs.os-type"
	osFamilyAttrName = "ecs.os-family"

	// RoundtripTimeout should only time out after dial and TLS handshake timeouts have elapsed.
	// Add additional 2 seconds to the sum of these 2 timeouts to be extra sure of this.
	RoundtripTimeout = httpclient.DefaultDialTimeout + httpclient.DefaultTLSHandshakeTimeout + 2*time.Second

	// Below constants are used for SetInstanceIdentity retry with exponential backoff.
	setInstanceIdRetryTimeOut         = 30 * time.Second
	setInstanceIdRetryBackoffMin      = 100 * time.Millisecond
	setInstanceIdRetryBackoffMax      = 5 * time.Second
	setInstanceIdRetryBackoffJitter   = 0.2
	setInstanceIdRetryBackoffMultiple = 2

	// discoverPollEndpointTimeout is the maximum permitted time a single ECSClient.DiscoverPollEndpoint call can take.
	// The SDK client uses the default retryer which gives a max retry count of 3, we combine this with the timeout
	// for the underlying httpclient's RoundtripTimeout.
	discoverPollEndpointTimeout = 3 * RoundtripTimeout

	// Below constants are used for RegisterContainerInstance retry with exponential backoff when receiving
	// non-terminal errors.
	// To ensure parity in all regions and on all launch types, we should not set any time limit on the RCI timeout.
	// Thus, setting the max RCI retry timeout allowed to 1 hour, and capping max retry backoff at 192 seconds (3 * 2^6).
	rciMaxRetryTimeAllowed = 1 * time.Hour
	rciMinBackoff          = 3 * time.Second
	rciMaxBackoff          = 192 * time.Second
	rciRetryJitter         = 0.2
	rciRetryMultiple       = 2.0
)

// nonRetriableErrors lists API error types for which state-change submission
// should not be retried.
// NOTE(review): submitStateCustomRetriableError checks these with errors.As on
// a smithy.APIError interface-typed variable, which matches ANY API error, not
// just these three concrete types — confirm the intended matching semantics.
var nonRetriableErrors = []smithy.APIError{
	new(types.AccessDeniedException),
	new(types.InvalidParameterException),
	new(types.ClientException),
}

// ecsClient implements ECSClient interface.
type ecsClient struct { credentialsCache *aws.CredentialsCache configAccessor config.AgentConfigAccessor standardClient ecs.ECSStandardSDK submitStateChangeClient ecs.ECSSubmitStateSDK ec2metadata ec2.EC2MetadataClient httpClient *http.Client pollEndpointCache async.TTLCache pollEndpointLock sync.Mutex isFIPSDetected bool shouldExcludeIPv6PortBinding bool sascCustomRetryBackoff func(func() error) error stscAttachmentCustomRetryBackoff func(func() error) error metricsFactory metrics.EntryFactory rciRetryBackoff *retry.ExponentialBackoff } // NewECSClient creates a new ECSClient interface object. func NewECSClient( credentialsCache *aws.CredentialsCache, configAccessor config.AgentConfigAccessor, ec2MetadataClient ec2.EC2MetadataClient, agentVer string, options ...ECSClientOption) (ecs.ECSClient, error) { client := &ecsClient{ credentialsCache: credentialsCache, configAccessor: configAccessor, ec2metadata: ec2MetadataClient, httpClient: httpclient.New(RoundtripTimeout, configAccessor.AcceptInsecureCert(), agentVer, configAccessor.OSType()), pollEndpointCache: async.NewTTLCache(&async.TTL{Duration: defaultPollEndpointCacheTTL}), } // Apply options to configure/override ECS client values. 
for _, opt := range options { opt(client) } ecsConfig, err := newECSConfig(client.credentialsCache, configAccessor, client.httpClient, client.isFIPSDetected) if err != nil { return nil, err } if client.standardClient == nil { client.standardClient = ecsservice.NewFromConfig(ecsConfig) } if client.submitStateChangeClient == nil { client.submitStateChangeClient = newSubmitStateChangeClient(ecsConfig) } if client.metricsFactory == nil { client.metricsFactory = metrics.NewNopEntryFactory() } if client.rciRetryBackoff == nil { client.rciRetryBackoff = retry.NewExponentialBackoff(rciMinBackoff, rciMaxBackoff, rciRetryJitter, rciRetryMultiple) } return client, nil } func newECSConfig( credentialsCache *aws.CredentialsCache, configAccessor config.AgentConfigAccessor, httpClient *http.Client, isFIPSEnabled bool, ) (aws.Config, error) { // We should respect the endpoint given (if any) because it could be the Gamma or Zeta endpoint of ECS service which // don't have the corresponding FIPS endpoints. Otherwise, when the host has FIPS enabled, we should tell SDK to // pick the FIPS endpoint. var endpointFn = func(_ *awsconfig.LoadOptions) error { return nil } if configAccessor.APIEndpoint() != "" { endpointFn = awsconfig.WithBaseEndpoint(configAccessor.APIEndpoint()) } else if isFIPSEnabled { endpointFn = awsconfig.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled) } ecsConfig, err := awsconfig.LoadDefaultConfig( context.TODO(), awsconfig.WithHTTPClient(httpClient), awsconfig.WithRegion(configAccessor.AWSRegion()), awsconfig.WithCredentialsProvider(credentialsCache), endpointFn, ) if err != nil { return aws.Config{}, err } return ecsConfig, nil } // CreateCluster creates a cluster from a given name and returns its ARN. 
func (client *ecsClient) CreateCluster(clusterName string) (string, error) {
	resp, err := client.standardClient.CreateCluster(context.TODO(),
		&ecsservice.CreateClusterInput{ClusterName: &clusterName})
	if err != nil {
		logger.Critical("Could not create cluster", logger.Fields{
			field.Cluster: clusterName,
			field.Error:   err,
		})
		return "", err
	}
	logger.Info("Successfully created a cluster", logger.Fields{
		field.Cluster: clusterName,
	})
	// NOTE(review): this returns the cluster NAME, although the doc comment says
	// "returns its ARN" — confirm which one callers expect.
	return *resp.Cluster.ClusterName, nil
}

// RegisterContainerInstance calculates the appropriate resources, creates
// the default cluster if necessary, and returns the registered
// ContainerInstanceARN if successful. Supplying a non-empty container
// instance ARN allows a container instance to update its registered
// resources.
// Returns (containerInstanceARN, availabilityZone, error).
func (client *ecsClient) RegisterContainerInstance(containerInstanceArn string,
	attributes []types.Attribute, tags []types.Tag, registrationToken string,
	platformDevices []types.PlatformDevice, outpostARN string) (string, string, error) {
	clusterRef := client.configAccessor.Cluster()
	// If our clusterRef is empty, we should try to create the default.
	if clusterRef == "" {
		clusterRef = client.configAccessor.DefaultClusterName()
		// Persist whichever cluster name we end up registering against.
		defer client.configAccessor.UpdateCluster(clusterRef)
		// Attempt to register without checking existence of the cluster so that we don't require
		// excess permissions in the case where the cluster already exists and is active.
		containerInstanceArn, availabilityzone, err := client.registerContainerInstanceWithRetry(clusterRef,
			containerInstanceArn, attributes, tags, registrationToken, platformDevices, outpostARN)
		if err == nil {
			return containerInstanceArn, availabilityzone, nil
		}
		// If trying to register fails because the default cluster doesn't exist, try to create the cluster before
		// calling register again.
		if apierrors.IsClusterNotFoundError(err) {
			clusterRef, err = client.CreateCluster(clusterRef)
			if err != nil {
				return "", "", err
			}
		}
	}
	return client.registerContainerInstanceWithRetry(clusterRef, containerInstanceArn, attributes, tags,
		registrationToken, platformDevices, outpostARN)
}

// registerContainerInstanceWithRetry wraps around registerContainerInstance with exponential backoff retry
// implementation. Transient errors (as judged by isTransientError) are retried for up to
// rciMaxRetryTimeAllowed; terminal errors stop the loop immediately.
func (client *ecsClient) registerContainerInstanceWithRetry(clusterRef string, containerInstanceArn string,
	attributes []types.Attribute, tags []types.Tag, registrationToken string,
	platformDevices []types.PlatformDevice, outpostARN string) (string, string, error) {
	var containerInstanceARN, availabilityZone string
	var errFromRCI error
	ctx, cancel := context.WithTimeout(context.Background(), rciMaxRetryTimeAllowed)
	defer cancel()
	// Reset the backoff such that retries from past calls won't impact the current call.
	client.rciRetryBackoff.Reset()
	err := retry.RetryWithBackoffCtx(ctx, client.rciRetryBackoff, func() error {
		containerInstanceARN, availabilityZone, errFromRCI = client.registerContainerInstance(
			clusterRef, containerInstanceArn, attributes, tags, registrationToken, platformDevices, outpostARN)
		if errFromRCI != nil {
			if !isTransientError(errFromRCI) {
				logger.Error("Received terminal error from RegisterContainerInstance call, exiting",
					logger.Fields{
						field.Error: errFromRCI,
					})
				// Mark the error as non-retriable, to stop the retry loop in RetryWithBackoffCtx.
				return apierrors.NewRetriableError(apierrors.NewRetriable(false), errFromRCI)
			} else {
				logger.Error("Received non-terminal error from RegisterContainerInstance call, retrying with exponential backoff",
					logger.Fields{
						field.Error: errFromRCI,
					})
				// Mark non-terminal errors as retriable, to continue the retry loop in RetryWithBackoffCtx.
				return apierrors.NewRetriableError(apierrors.NewRetriable(true), errFromRCI)
			}
		}
		return nil
	})
	if err != nil {
		// return errFromRCI instead of err returned by the retry wrapper, as err wraps around the original error
		// thrown by RCI.
		// errFromRCI has implementation to mark the exit code terminal, so that systemd won't restart the agent
		// binary.
		return "", "", errFromRCI
	}
	return containerInstanceARN, availabilityZone, nil
}

// registerContainerInstance performs a single RegisterContainerInstance API call, assembling the
// request's attributes, tags, platform devices, instance identity document, and host resources.
// Returns (containerInstanceARN, availabilityZone, error); the AZ is extracted from the
// "ecs.availability-zone" attribute of the response.
func (client *ecsClient) registerContainerInstance(clusterRef string, containerInstanceArn string,
	attributes []types.Attribute, tags []types.Tag, registrationToken string,
	platformDevices []types.PlatformDevice, outpostARN string) (string, string, error) {
	registerRequest := ecsservice.RegisterContainerInstanceInput{Cluster: &clusterRef}
	var registrationAttributes []types.Attribute
	if containerInstanceArn != "" {
		// We are re-connecting a previously registered instance, restored from snapshot.
		registerRequest.ContainerInstanceArn = &containerInstanceArn
	} else {
		// This is a new instance, not previously registered.
		// Custom attribute registration only happens on initial instance registration.
		for _, attribute := range client.getCustomAttributes() {
			logger.Debug("Added a new custom attribute", logger.Fields{
				field.AttributeName:  aws.ToString(attribute.Name),
				field.AttributeValue: aws.ToString(attribute.Value),
			})
			registrationAttributes = append(registrationAttributes, attribute)
		}
	}
	// Standard attributes are included with all registrations.
	registrationAttributes = append(registrationAttributes, attributes...)
	// Add additional attributes, such as the OS type.
	registrationAttributes = append(registrationAttributes, client.getAdditionalAttributes()...)
	registrationAttributes = append(registrationAttributes, client.getOutpostAttribute(outpostARN)...)
	registerRequest.Attributes = registrationAttributes
	if len(tags) > 0 {
		registerRequest.Tags = tags
	}
	registerRequest.PlatformDevices = platformDevices
	registerRequest = client.setInstanceIdentity(registerRequest)
	resources, err := client.getResources()
	if err != nil {
		return "", "", err
	}
	registerRequest.TotalResources = resources
	// ClientToken makes the registration call idempotent on the backend.
	registerRequest.ClientToken = &registrationToken
	resp, err := client.standardClient.RegisterContainerInstance(context.TODO(), &registerRequest)
	if err != nil {
		logger.Error("Unable to register as a container instance with ECS", logger.Fields{
			field.Error: err,
		})
		return "", "", err
	}
	var availabilityzone = ""
	if resp != nil {
		for _, attr := range resp.ContainerInstance.Attributes {
			if aws.ToString(attr.Name) == azAttrName {
				availabilityzone = aws.ToString(attr.Value)
				break
			}
		}
	}
	logger.Info("Registered container instance with cluster!")
	// Any attribute-validation error is returned alongside the successfully registered ARN.
	err = validateRegisteredAttributes(registerRequest.Attributes, resp.ContainerInstance.Attributes)
	return aws.ToString(resp.ContainerInstance.ContainerInstanceArn), availabilityzone, err
}

// setInstanceIdentity populates the request's instance identity document and signature,
// fetching both from EC2 instance metadata with retry/backoff. When fetching is disabled
// via configuration, both fields are set to the empty string. On unrecoverable fetch
// failure the document (and/or signature) is likewise left empty rather than failing the
// registration.
func (client *ecsClient) setInstanceIdentity(
	registerRequest ecsservice.RegisterContainerInstanceInput) ecsservice.RegisterContainerInstanceInput {
	instanceIdentityDoc := ""
	instanceIdentitySignature := ""
	if client.configAccessor.NoInstanceIdentityDocument() {
		logger.Info("Fetching Instance ID Document has been disabled")
		registerRequest.InstanceIdentityDocument = &instanceIdentityDoc
		registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature
		return registerRequest
	}
	iidRetrieved := true
	backoff := retry.NewExponentialBackoff(setInstanceIdRetryBackoffMin, setInstanceIdRetryBackoffMax,
		setInstanceIdRetryBackoffJitter, setInstanceIdRetryBackoffMultiple)
	ctx, cancel := context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut)
	defer cancel()
	err := retry.RetryWithBackoffCtx(ctx, backoff, func() error {
		var attemptErr error
		logger.Debug("Attempting to get Instance Identity Document")
		instanceIdentityDoc, attemptErr = client.ec2metadata.GetDynamicData(ec2.InstanceIdentityDocumentResource)
		if attemptErr != nil {
			logger.Error("Unable to get instance identity document, retrying", logger.Fields{
				field.Error: attemptErr,
			})
			// Force credentials to expire in case they are stale but not expired.
			client.credentialsCache.Invalidate()
			if creds, err := client.credentialsCache.Retrieve(ctx); err != nil || !creds.HasKeys() {
				logger.Error("Unable to get valid credentials after invalidating credentials cache",
					logger.Fields{
						field.Error: err,
					})
			}
			return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr)
		}
		logger.Debug("Successfully retrieved Instance Identity Document")
		return nil
	})
	if err != nil {
		logger.Error("Unable to get instance identity document", logger.Fields{
			field.Error: err,
		})
		iidRetrieved = false
	}
	registerRequest.InstanceIdentityDocument = &instanceIdentityDoc
	// Only bother fetching the signature if the document itself was retrieved.
	if iidRetrieved {
		// NOTE(review): the same backoff instance from the document loop is reused here
		// without Reset(), so the signature retries start at the already-grown delay —
		// confirm this is intended.
		ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut)
		defer cancel()
		err = retry.RetryWithBackoffCtx(ctx, backoff, func() error {
			var attemptErr error
			logger.Debug("Attempting to get Instance Identity Signature")
			instanceIdentitySignature, attemptErr = client.ec2metadata.
				GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource)
			if attemptErr != nil {
				logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{
					field.Error: attemptErr,
				})
				return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr)
			}
			logger.Debug("Successfully retrieved Instance Identity Signature")
			return nil
		})
		if err != nil {
			logger.Error("Unable to get instance identity signature", logger.Fields{
				field.Error: err,
			})
		}
	}
	registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature
	return registerRequest
}

// attributesToMap flattens a slice of attributes into a name -> value map.
func attributesToMap(attributes []types.Attribute) map[string]string {
	attributeMap := make(map[string]string)
	attribs := attributes
	for _, attribute := range attribs {
		attributeMap[aws.ToString(attribute.Name)] = aws.ToString(attribute.Value)
	}
	return attributeMap
}

// findMissingAttributes returns the keys whose values in actualAttributes do not match
// expectedAttributes, plus an AttributeError when at least one mismatch exists.
func findMissingAttributes(expectedAttributes, actualAttributes map[string]string) ([]string, error) {
	missingAttributes := make([]string, 0)
	var err error
	for key, val := range expectedAttributes {
		if actualAttributes[key] != val {
			missingAttributes = append(missingAttributes, key)
		} else {
			logger.Trace("Response contained expected value for attribute", logger.Fields{
				"key": key,
			})
		}
	}
	if len(missingAttributes) > 0 {
		err = apierrors.NewAttributeError("Attribute validation failed")
	}
	return missingAttributes, err
}

// getResources computes the CPU, MEMORY, PORTS, and PORTS_UDP resources to report in
// RegisterContainerInstance. MEMORY is the host total minus the configured reservation;
// an error is returned when the reservation exceeds the host total.
func (client *ecsClient) getResources() ([]types.Resource, error) {
	// Below are micro-optimizations - the pointers to integerStr and stringSetStr are used multiple times below.
	integerStr := "INTEGER"
	stringSetStr := "STRINGSET"
	cpu, mem := getCpuAndMemory()
	remainingMem := mem - int32(client.configAccessor.ReservedMemory())
	logger.Info("Remaining memory", logger.Fields{
		"remainingMemory": remainingMem,
	})
	if remainingMem < 0 {
		return nil, fmt.Errorf(
			"api register-container-instance: reserved memory is higher than available memory on the host, "+
				"total memory: %d, reserved: %d", mem, client.configAccessor.ReservedMemory())
	}
	cpuResource := types.Resource{
		Name:         aws.String("CPU"),
		Type:         &integerStr,
		IntegerValue: cpu,
	}
	memResource := types.Resource{
		Name:         aws.String("MEMORY"),
		Type:         &integerStr,
		IntegerValue: remainingMem,
	}
	portResource := types.Resource{
		Name:           aws.String("PORTS"),
		Type:           &stringSetStr,
		StringSetValue: utils.Uint16SliceToStringSlice(client.configAccessor.ReservedPorts()),
	}
	udpPortResource := types.Resource{
		Name:           aws.String("PORTS_UDP"),
		Type:           &stringSetStr,
		StringSetValue: utils.Uint16SliceToStringSlice(client.configAccessor.ReservedPortsUDP()),
	}
	return []types.Resource{cpuResource, memResource, portResource, udpPortResource}, nil
}

// GetHostResources calls getResources to get the list of CPU, MEMORY, PORTS and PORTS_UDP
// resources and returns a map keyed by resource name. The TCP resource is renamed from
// "PORTS" to "PORTS_TCP" for the map.
func (client *ecsClient) GetHostResources() (map[string]types.Resource, error) {
	resources, err := client.getResources()
	if err != nil {
		return nil, err
	}
	resourceMap := make(map[string]types.Resource)
	for _, resource := range resources {
		if *resource.Name == "PORTS" {
			// Except for RCI, TCP Ports are named as PORTS_TCP in Agent for Host Resources purpose.
			resource.Name = aws.String("PORTS_TCP")
		}
		resourceMap[*resource.Name] = resource
	}
	return resourceMap, nil
}

// getCpuAndMemory returns (cpuUnits, memoryMiB) for the host: CPU count * 1024 units,
// and total memory read from meminfo (0 if meminfo cannot be read).
func getCpuAndMemory() (int32, int32) {
	memInfo, err := meminfo.Read()
	mem := int32(0)
	if err == nil {
		mem = int32(memInfo.MemTotal / 1024 / 1024) // MiB
	} else {
		logger.Error("Unable to get memory info", logger.Fields{
			field.Error: err,
		})
	}
	cpu := utils.GetNumCPU() * 1024
	return int32(cpu), mem
}

// validateRegisteredAttributes compares the attributes we sent against the attributes the
// backend echoed back and returns an error (with the mismatches logged) when any differ.
func validateRegisteredAttributes(expectedAttributes, actualAttributes []types.Attribute) error {
	var err error
	expectedAttributesMap := attributesToMap(expectedAttributes)
	actualAttributesMap := attributesToMap(actualAttributes)
	missingAttributes, err := findMissingAttributes(expectedAttributesMap, actualAttributesMap)
	if err != nil {
		msg := strings.Join(missingAttributes, ",")
		logger.Error("Error registering attributes", logger.Fields{
			field.Error:         err,
			"missingAttributes": msg,
		})
	}
	return err
}

// getAdditionalAttributes builds the standard attributes (OS type, OS family, and — on
// external capacity only — CPU architecture) included with every registration.
func (client *ecsClient) getAdditionalAttributes() []types.Attribute {
	var attrs []types.Attribute
	// Add a check to ensure only non-empty values are added
	// to API call.
	if client.configAccessor.OSType() != "" {
		attrs = append(attrs, types.Attribute{
			Name:  aws.String(osTypeAttrName),
			Value: aws.String(client.configAccessor.OSType()),
		})
	}
	// OSFamily should be treated as an optional field as it is not applicable for all agents
	// using ecs client shared library. Add a check to ensure only non-empty values are added
	// to API call.
	if client.configAccessor.OSFamily() != "" {
		attrs = append(attrs, types.Attribute{
			Name:  aws.String(osFamilyAttrName),
			Value: aws.String(client.configAccessor.OSFamily()),
		})
	}
	// Send CPU arch attribute directly when running on external capacity. When running on EC2 or Fargate launch
	// type, this is not needed since the CPU arch is reported via instance identity document in those cases.
	if client.configAccessor.External() {
		attrs = append(attrs, types.Attribute{
			Name:  aws.String(cpuArchAttrName),
			Value: aws.String(getCPUArch()),
		})
	}
	return attrs
}

// getOutpostAttribute returns the "ecs.outpost-arn" attribute when an Outpost ARN is
// supplied, or an empty slice otherwise.
func (client *ecsClient) getOutpostAttribute(outpostARN string) []types.Attribute {
	if len(outpostARN) > 0 {
		return []types.Attribute{
			{
				Name:  aws.String("ecs.outpost-arn"),
				Value: aws.String(outpostARN),
			},
		}
	}
	return []types.Attribute{}
}

// getCustomAttributes converts the operator-configured instance attributes into API attributes.
func (client *ecsClient) getCustomAttributes() []types.Attribute {
	var attributes []types.Attribute
	for attribute, value := range client.configAccessor.InstanceAttributes() {
		attributes = append(attributes, types.Attribute{
			Name:  aws.String(attribute),
			Value: aws.String(value),
		})
	}
	return attributes
}

// SubmitTaskStateChange submits a task state change. Attachment-bearing changes are routed
// through the custom retry backoff when one is configured.
func (client *ecsClient) SubmitTaskStateChange(change ecs.TaskStateChange) error {
	if change.Attachment != nil && client.stscAttachmentCustomRetryBackoff != nil {
		retryFunc := func() error {
			err := client.submitTaskStateChange(change)
			if err == nil {
				return nil
			}
			return submitStateCustomRetriableError(err)
		}
		return client.stscAttachmentCustomRetryBackoff(retryFunc)
	}
	return client.submitTaskStateChange(change)
}

// submitTaskStateChange performs the SubmitTaskStateChange API call. When the change
// carries an attachment, only the attachment state is submitted; otherwise the full task
// state (status, reason, timestamps, managed agents, containers) is sent.
func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error {
	clusterARN := client.configAccessor.Cluster()
	if len(change.ClusterARN) != 0 {
		clusterARN = change.ClusterARN
	}
	if change.Attachment != nil {
		// Confirm attachment by submitting attachment state change via SubmitTaskStateChange API (specifically in
		// the input's Attachments field).
		var attachments []types.AttachmentStateChange
		eniStatus := change.Attachment.Status.String()
		attachments = []types.AttachmentStateChange{
			{
				AttachmentArn: aws.String(change.Attachment.AttachmentARN),
				Status:        aws.String(eniStatus),
			},
		}
		_, err := client.submitStateChangeClient.SubmitTaskStateChange(context.TODO(),
			&ecsservice.SubmitTaskStateChangeInput{
				Cluster:     aws.String(clusterARN),
				Task:        aws.String(change.TaskARN),
				Attachments: attachments,
			})
		if err != nil {
			logger.Warn("Could not submit task state change associated with confirming attachment",
				logger.Fields{
					field.Error:     err,
					"attachmentARN": change.Attachment.AttachmentARN,
					field.Status:    eniStatus,
				})
			return err
		}
		return nil
	}
	req := ecsservice.SubmitTaskStateChangeInput{
		Cluster:            aws.String(clusterARN),
		Task:               aws.String(change.TaskARN),
		Status:             aws.String(change.Status.BackendStatus()),
		Reason:             aws.String(trimString(change.Reason, ecsMaxTaskReasonLength)),
		PullStartedAt:      change.PullStartedAt,
		PullStoppedAt:      change.PullStoppedAt,
		ExecutionStoppedAt: change.ExecutionStoppedAt,
		ManagedAgents:      formatManagedAgents(change.ManagedAgents),
		Containers:         formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN),
	}
	_, err := client.submitStateChangeClient.SubmitTaskStateChange(context.TODO(), &req)
	if err != nil {
		logger.Error("Could not submit task state change", logger.Fields{
			field.Error:       err,
			"taskStateChange": change.String(),
		})
		return err
	}
	return nil
}

// SubmitContainerStateChange submits a container state change. Only RUNNING and STOPPED
// are supported upstream; a "DEAD" status is normalized to STOPPED, and any other status
// is dropped without an error.
func (client *ecsClient) SubmitContainerStateChange(change ecs.ContainerStateChange) error {
	input := ecsservice.SubmitContainerStateChangeInput{
		Cluster:       aws.String(client.configAccessor.Cluster()),
		ContainerName: aws.String(change.ContainerName),
		Task:          aws.String(change.TaskArn),
	}
	if change.RuntimeID != "" {
		input.RuntimeId = aws.String(trimString(change.RuntimeID, ecsMaxRuntimeIDLength))
	}
	if change.Reason != "" {
		input.Reason = aws.String(trimString(change.Reason, ecsMaxContainerReasonLength))
	}
	stat := change.Status.String()
	if stat == "DEAD" {
		stat = apicontainerstatus.ContainerStopped.String()
	}
	if stat != apicontainerstatus.ContainerStopped.String() && stat != apicontainerstatus.ContainerRunning.String() {
		logger.Info("Not submitting unsupported upstream container state", logger.Fields{
			field.ContainerName: change.ContainerName,
			field.Status:        stat,
			field.TaskARN:       change.TaskArn,
		})
		return nil
	}
	input.Status = aws.String(stat)
	if change.ExitCode != nil {
		exitCode := int32(aws.ToInt(change.ExitCode))
		input.ExitCode = aws.Int32(exitCode)
	}
	networkBindings := change.NetworkBindings
	if client.shouldExcludeIPv6PortBinding {
		networkBindings = excludeIPv6PortBindingFromNetworkBindings(networkBindings, change.ContainerName,
			change.TaskArn)
	}
	input.NetworkBindings = networkBindings
	_, err := client.submitStateChangeClient.SubmitContainerStateChange(context.TODO(), &input)
	if err != nil {
		logger.Error("Could not submit container state change", logger.Fields{
			field.Error:            err,
			field.TaskARN:          change.TaskArn,
			"containerStateChange": change.String(),
		})
		return err
	}
	return nil
}

// SubmitAttachmentStateChange submits an attachment state change, via the custom retry
// backoff when one is configured.
func (client *ecsClient) SubmitAttachmentStateChange(change ecs.AttachmentStateChange) error {
	if client.sascCustomRetryBackoff != nil {
		retryFunc := func() error {
			err := client.submitAttachmentStateChange(change)
			if err == nil {
				return nil
			}
			return submitStateCustomRetriableError(err)
		}
		return client.sascCustomRetryBackoff(retryFunc)
	}
	return client.submitAttachmentStateChange(change)
}

// submitAttachmentStateChange performs the SubmitAttachmentStateChanges API call for a
// single attachment.
func (client *ecsClient) submitAttachmentStateChange(change ecs.AttachmentStateChange) error {
	attachmentStatus := change.Attachment.GetAttachmentStatus()
	req := ecsservice.SubmitAttachmentStateChangesInput{
		Cluster: aws.String(client.configAccessor.Cluster()),
		Attachments: []types.AttachmentStateChange{
			{
				AttachmentArn: aws.String(change.Attachment.GetAttachmentARN()),
				Status:        aws.String(attachmentStatus.String()),
			},
		},
	}
	_, err := client.submitStateChangeClient.SubmitAttachmentStateChanges(context.TODO(), &req)
	if err != nil {
		logger.Warn("Could not submit attachment state change", logger.Fields{
			field.Error:             err,
			"attachmentStateChange": change.String(),
		})
		return err
	}
	return nil
}

// submitStateCustomRetriableError wraps err as retriable/non-retriable for the custom
// backoff wrappers, based on the nonRetriableErrors list.
// NOTE(review): errors.As is called with a pointer to the interface-typed loop variable
// (smithy.APIError), which matches ANY error implementing smithy.APIError regardless of
// the concrete element — so the first iteration likely decides the outcome for every API
// error. Confirm whether matching was meant to be per concrete type.
func submitStateCustomRetriableError(err error) error {
	retry := true
	for _, apiErr := range nonRetriableErrors {
		if errors.As(err, &apiErr) {
			retry = false
			break
		}
	}
	return apierrors.NewRetriableError(apierrors.NewRetriable(retry), err)
}

// DiscoverPollEndpoint returns the ACS endpoint for this container instance.
func (client *ecsClient) DiscoverPollEndpoint(containerInstanceArn string) (string, error) {
	resp, err := client.discoverPollEndpoint(containerInstanceArn, "")
	if err != nil {
		return "", err
	}
	if resp.Endpoint == nil {
		return "", errors.New("no endpoint returned; nil")
	}
	return aws.ToString(resp.Endpoint), nil
}

// DiscoverTelemetryEndpoint returns the TACS (telemetry) endpoint for this container instance.
func (client *ecsClient) DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) {
	resp, err := client.discoverPollEndpoint(containerInstanceArn, "")
	if err != nil {
		return "", err
	}
	if resp.TelemetryEndpoint == nil {
		return "", errors.New("no telemetry endpoint returned; nil")
	}
	return aws.ToString(resp.TelemetryEndpoint), nil
}

// DiscoverServiceConnectEndpoint returns the Service Connect endpoint for this container instance.
func (client *ecsClient) DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) {
	resp, err := client.discoverPollEndpoint(containerInstanceArn, "")
	if err != nil {
		return "", err
	}
	if resp.ServiceConnectEndpoint == nil {
		return "", errors.New("no ServiceConnect endpoint returned; nil")
	}
	return aws.ToString(resp.ServiceConnectEndpoint), nil
}

// DiscoverSystemLogsEndpoint returns the system logs endpoint for this container instance;
// the availability zone is forwarded to the DiscoverPollEndpoint call.
func (client *ecsClient) DiscoverSystemLogsEndpoint(containerInstanceArn string,
	availabilityZone string) (string, error) {
	resp, err := client.discoverPollEndpoint(containerInstanceArn, availabilityZone)
	if err != nil {
		return "", err
	}
	if resp.SystemLogsEndpoint == nil {
		return "", errors.New("no system logs endpoint returned; nil")
	}
	return aws.ToString(resp.SystemLogsEndpoint), nil
}

// discoverPollEndpoint returns the DiscoverPollEndpoint output for this instance, serving
// from the TTL cache when possible, and falling back to an expired cache entry when the
// API call fails. Serialized by pollEndpointLock.
func (client *ecsClient) discoverPollEndpoint(containerInstanceArn string,
	availabilityZone string) (*ecsservice.DiscoverPollEndpointOutput, error) {
	client.pollEndpointLock.Lock()
	defer client.pollEndpointLock.Unlock()
	ctx, cancel := context.WithTimeout(context.Background(), discoverPollEndpointTimeout)
	defer cancel()
	// Try getting an entry from the cache.
	cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn)
	if !expired && found {
		// Cache hit and not expired. Return the output.
		output, ok := cachedEndpoint.(*ecsservice.DiscoverPollEndpointOutput)
		// NOTE(review): output is dereferenced here before ok is checked — if the cache
		// ever holds a different type this panics on a nil pointer. Confirm invariants.
		systemLogsEndpoint := aws.ToString(output.SystemLogsEndpoint)
		if ok {
			// Presence of the system logs endpoint can be disregarded if the AZ was not provided,
			// but the cache hit must include a non-empty system logs endpoint if the AZ was provided.
			if availabilityZone == "" || (availabilityZone != "" && systemLogsEndpoint != "") {
				logger.Info("Using cached DiscoverPollEndpoint", logger.Fields{
					field.Endpoint:               aws.ToString(output.Endpoint),
					field.TelemetryEndpoint:      aws.ToString(output.TelemetryEndpoint),
					field.ServiceConnectEndpoint: aws.ToString(output.ServiceConnectEndpoint),
					field.SystemLogsEndpoint:     systemLogsEndpoint,
					field.ContainerInstanceARN:   containerInstanceArn,
				})
				return output, nil
			}
		}
	}
	discoverPollEndpointStartTime := time.Now()
	// Cache miss or expired, invoke the ECS DiscoverPollEndpoint API.
	logger.Debug("Invoking DiscoverPollEndpoint", logger.Fields{
		field.ContainerInstanceARN: containerInstanceArn,
		field.AvailabilityZone:     availabilityZone,
	})
	output, err := client.standardClient.DiscoverPollEndpoint(ctx, &ecsservice.DiscoverPollEndpointInput{
		ContainerInstance: &containerInstanceArn,
		Cluster:           aws.String(client.configAccessor.Cluster()),
		ZoneId:            aws.String(availabilityZone),
	})
	client.metricsFactory.New(metrics.DiscoverPollEndpointFailure).Done(err)
	client.metricsFactory.New(metrics.DiscoverPollEndpointTotal).Done(nil)
	if err != nil {
		// If we got an error calling the API, fallback to an expired cached endpoint if
		// we have it.
		if expired {
			if output, ok := cachedEndpoint.(*ecsservice.DiscoverPollEndpointOutput); ok {
				logger.Info("Error calling DiscoverPollEndpoint. Using cached-but-expired endpoint as a fallback.",
					logger.Fields{
						field.Endpoint:               aws.ToString(output.Endpoint),
						field.TelemetryEndpoint:      aws.ToString(output.TelemetryEndpoint),
						field.ServiceConnectEndpoint: aws.ToString(output.ServiceConnectEndpoint),
						field.SystemLogsEndpoint:     aws.ToString(output.SystemLogsEndpoint),
						field.ContainerInstanceARN:   containerInstanceArn,
					})
				return output, nil
			}
		}
		return nil, err
	}
	client.metricsFactory.New(metrics.DiscoverPollEndpointDurationName).
		WithGauge(time.Since(discoverPollEndpointStartTime).Milliseconds()).Done(nil)
	// Cache the response from ECS.
	client.pollEndpointCache.Set(containerInstanceArn, output)
	return output, nil
}

// GetResourceTags returns the tags on the given ECS resource ARN.
func (client *ecsClient) GetResourceTags(resourceArn string) ([]types.Tag, error) {
	output, err := client.standardClient.ListTagsForResource(context.TODO(),
		&ecsservice.ListTagsForResourceInput{
			ResourceArn: &resourceArn,
		})
	if err != nil {
		return nil, err
	}
	return output.Tags, nil
}

// UpdateContainerInstancesState sets the status (e.g. DRAINING) of this container instance.
func (client *ecsClient) UpdateContainerInstancesState(instanceARN string,
	status types.ContainerInstanceStatus) error {
	logger.Debug("Invoking UpdateContainerInstancesState", logger.Fields{
		field.Status:               status,
		field.ContainerInstanceARN: instanceARN,
	})
	_, err := client.standardClient.UpdateContainerInstancesState(context.TODO(),
		&ecsservice.UpdateContainerInstancesStateInput{
			ContainerInstances: []string{instanceARN},
			Status:             status,
			Cluster:            aws.String(client.configAccessor.Cluster()),
		})
	return err
}

// formatManagedAgents trims managed-agent reasons to the backend's maximum length.
func formatManagedAgents(managedAgents []types.ManagedAgentStateChange) []types.ManagedAgentStateChange {
	var result []types.ManagedAgentStateChange
	for _, m := range managedAgents {
		if m.Reason != nil {
			m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength)
		}
		result = append(result, m)
	}
	return result
}

// formatContainers trims per-container fields to backend limits and, when configured,
// strips IPv6 port bindings before submission.
func formatContainers(containers []types.ContainerStateChange,
	shouldExcludeIPv6PortBinding bool, taskARN string) []types.ContainerStateChange {
	var result []types.ContainerStateChange
	for _, c := range containers {
		if c.RuntimeId != nil {
			c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength)
		}
		if c.Reason != nil {
			c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength)
		}
		if c.ImageDigest != nil {
			c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength)
		}
		if shouldExcludeIPv6PortBinding {
			c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings,
				aws.ToString(c.ContainerName), taskARN)
		}
		result = append(result, c)
	}
	return result
}

// excludeIPv6PortBindingFromNetworkBindings filters out bindings whose bind IP is the
// IPv6 wildcard "::", logging each exclusion.
func excludeIPv6PortBindingFromNetworkBindings(networkBindings []types.NetworkBinding,
	containerName, taskARN string) []types.NetworkBinding {
	var result []types.NetworkBinding
	for _, binding := range networkBindings {
		if aws.ToString(binding.BindIP) == "::" {
			logger.Debug("Exclude IPv6 port binding", logger.Fields{
				"portBinding":       binding,
				field.ContainerName: containerName,
				field.TaskARN:       taskARN,
			})
			continue
		}
		result = append(result, binding)
	}
	return result
}

// trimStringPtr trims the pointed-to string to maxLen; nil passes through unchanged.
func trimStringPtr(inputStringPtr *string, maxLen int) *string {
	if inputStringPtr == nil {
		return nil
	}
	return aws.String(trimString(aws.ToString(inputStringPtr), maxLen))
}

// trimString truncates inputString to at most maxLen bytes.
func trimString(inputString string, maxLen int) string {
	if len(inputString) > maxLen {
		trimmed := inputString[0:maxLen]
		return trimmed
	} else {
		return inputString
	}
}

// isTransientError reports whether err is an API error worth retrying
// (server-side failure or throttling).
func isTransientError(err error) bool {
	var apiErr smithy.APIError
	// Using errors.As to unwrap as opposed to errors.Is.
	if errors.As(err, &apiErr) {
		switch apiErr.ErrorCode() {
		case apierrors.ErrCodeServerException, apierrors.ErrCodeLimitExceededException:
			return true
		}
	}
	return false
}