agent/acs/session/payload_responder.go (268 lines of code) (raw):
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package session
import (
"fmt"
"github.com/aws/amazon-ecs-agent/agent/api"
apitask "github.com/aws/amazon-ecs-agent/agent/api/task"
"github.com/aws/amazon-ecs-agent/agent/data"
"github.com/aws/amazon-ecs-agent/agent/engine"
"github.com/aws/amazon-ecs-agent/agent/eventhandler"
"github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs"
apiresource "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource"
"github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs"
apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
loggerfield "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
nlappmesh "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/appmesh"
ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/pkg/errors"
)
// skipAddTaskComparatorFunc is a predicate on a task's desired status. addTasks
// skips (does not add to the task engine) every task for which it returns true.
type skipAddTaskComparatorFunc func(apitaskstatus.TaskStatus) bool
// payloadMessageHandler implements the PayloadMessageHandler interface defined
// in the ecs-agent module. It turns ACS payload messages into tasks in the task
// engine and produces the corresponding message and credentials acks.
type payloadMessageHandler struct {
// taskEngine receives every valid task from a payload (see addTasks).
taskEngine engine.TaskEngine
// ecsClient is passed to taskHandler when reporting invalid tasks as STOPPED.
ecsClient ecs.ECSClient
// dataClient persists newly received tasks whose desired status is RUNNING.
dataClient data.Client
// taskHandler accepts the STOPPED state change events for invalid tasks.
taskHandler *eventhandler.TaskHandler
// credentialsManager stores the task and execution role credentials carried
// in the payload and is queried when building credentials acks.
credentialsManager credentials.Manager
// latestSeqNumberTaskManifest, when non-nil, is advanced by ProcessMessage to
// the highest payload sequence number seen so that it gets saved in the state file.
latestSeqNumberTaskManifest *int64
}
// NewPayloadMessageHandler returns a payloadMessageHandler wired to the given
// task engine, ECS client, data client, task event handler and credentials
// manager. latestSeqNumberTaskManifest points at a caller-owned sequence number
// that ProcessMessage advances; it may be nil, in which case it is not tracked.
func NewPayloadMessageHandler(taskEngine engine.TaskEngine,
	ecsClient ecs.ECSClient,
	dataClient data.Client,
	taskHandler *eventhandler.TaskHandler,
	credentialsManager credentials.Manager,
	latestSeqNumberTaskManifest *int64) *payloadMessageHandler {
	handler := new(payloadMessageHandler)
	handler.taskEngine = taskEngine
	handler.ecsClient = ecsClient
	handler.dataClient = dataClient
	handler.taskHandler = taskHandler
	handler.credentialsManager = credentialsManager
	handler.latestSeqNumberTaskManifest = latestSeqNumberTaskManifest
	return handler
}
// ProcessMessage handles a PayloadMessage received from ACS.
//
// Every task in the message is validated and added to the task engine. The
// tracked task manifest sequence number, when present, is advanced to the
// message's SeqNum if that is newer, whether or not all tasks were handled.
// If every task was handled, the message and the credentials of its tasks are
// acknowledged asynchronously through ackFunc and nil is returned. Otherwise
// nothing is acknowledged and an error is returned.
func (pmHandler *payloadMessageHandler) ProcessMessage(message *ecsacs.PayloadMessage,
	ackFunc func(*ecsacs.AckRequest, []*ecsacs.IAMRoleCredentialsAckRequest)) error {
	credsAcks, handledAll := pmHandler.addPayloadTasks(message)

	// Remember the newest sequence number so that it gets persisted in the state file.
	if seq := pmHandler.latestSeqNumberTaskManifest; seq != nil && message.SeqNum != nil && *message.SeqNum > *seq {
		*seq = *message.SeqNum
	}

	if handledAll {
		ackReq := &ecsacs.AckRequest{
			Cluster:           message.ClusterArn,
			ContainerInstance: message.ContainerInstanceArn,
			MessageId:         message.MessageId,
		}
		// Acknowledge in a separate goroutine so sending the ACK does not block
		// the handling of further tasks.
		go ackFunc(ackReq, credsAcks)
		return nil
	}
	return errors.Errorf("did not handle all tasks")
}
// addPayloadTasks validates every task in payload and adds the valid ones to
// the task engine. It returns the credentials ack requests generated for the
// added tasks, and a bool that is true only if every task in the payload was
// handled successfully.
//
// A task that fails conversion (see apiTaskFromPayloadTask) is reported to the
// backend as STOPPED via handleInvalidTask and is NOT added to the task engine.
func (pmHandler *payloadMessageHandler) addPayloadTasks(payload *ecsacs.PayloadMessage) (
	[]*ecsacs.IAMRoleCredentialsAckRequest, bool) {
	// Track whether every task in this payload was handled, so the caller
	// knows whether to ACK the whole message.
	allTasksOK := true
	validTasks := make([]*apitask.Task, 0, len(payload.Tasks))
	for _, task := range payload.Tasks {
		if task == nil {
			logger.Critical("Received nil task for message", logger.Fields{
				loggerfield.MessageID: aws.ToString(payload.MessageId),
			})
			allTasksOK = false
			continue
		}
		apiTask, err := pmHandler.apiTaskFromPayloadTask(task, payload)
		if err != nil {
			pmHandler.handleInvalidTask(task, err, payload)
			allTasksOK = false
			continue
		}
		validTasks = append(validTasks, apiTask)
	}

	// Add 'stop' transitions first to allow seqnum ordering to work out.
	// Because a 'start' sequence number should only be processed once all
	// 'stop's of the same sequence number have completed, the 'start' events
	// need to be added after the 'stop' events are there to block them.
	stoppedTasksCredentialsAcks, stoppedTasksAddedOK := pmHandler.addTasks(payload, validTasks, isTaskStatusNotStopped)
	newTasksCredentialsAcks, newTasksAddedOK := pmHandler.addTasks(payload, validTasks, isTaskStatusStopped)
	if !stoppedTasksAddedOK || !newTasksAddedOK {
		allTasksOK = false
	}

	// Construct a slice with credentials acks from all tasks.
	credentialsAcks := append(stoppedTasksCredentialsAcks, newTasksCredentialsAcks...)
	return credentialsAcks, allTasksOK
}

// apiTaskFromPayloadTask converts a single, non-nil ACS task into an agent task.
// It registers the task's application and execution role credentials with the
// credentials manager and attaches its ENIs and App Mesh configuration.
// It returns the first error encountered; the caller must then treat the whole
// task as invalid.
//
// NOTE(review): credentials registered before a later step fails are not
// removed from the credentials manager here (unchanged from earlier behavior).
func (pmHandler *payloadMessageHandler) apiTaskFromPayloadTask(task *ecsacs.Task,
	payload *ecsacs.PayloadMessage) (*apitask.Task, error) {
	// Note: If we receive an EBS-backed task, we'll also receive an incomplete volume configuration in the list of Volumes.
	// To accommodate this, we first check whether the task IS EBS-backed, and if so mark the corresponding Volume object
	// as type "attachment". This volume object will be replaced by the newly created EBS volume configuration when we
	// parse through the task attachments.
	if volName, ok := hasEBSAttachment(task); ok {
		initializeAttachmentTypeVolume(task, volName)
	}

	apiTask, err := apitask.TaskFromACS(task, payload)
	if err != nil {
		return nil, err
	}
	logger.Info("Received task payload from ACS", logger.Fields{
		loggerfield.TaskARN:       apiTask.Arn,
		loggerfield.TaskVersion:   apiTask.Version,
		loggerfield.DesiredStatus: apiTask.GetDesiredStatus().String(),
	})
	if apiTask.IsFaultInjectionEnabled() {
		logger.Info("Fault Injection Enabled for task", logger.Fields{
			loggerfield.TaskARN: apiTask.Arn,
		})
	}

	if task.RoleCredentials != nil {
		// The payload from ACS has credentials for the task. Add those to the
		// credentials manager and set the credentials id for the task as well.
		taskIAMRoleCredentials := credentials.IAMRoleCredentialsFromACS(task.RoleCredentials,
			credentials.ApplicationRoleType)
		if err := pmHandler.credentialsManager.SetTaskCredentials(&credentials.TaskIAMRoleCredentials{
			ARN:                aws.ToString(task.Arn),
			IAMRoleCredentials: taskIAMRoleCredentials,
		}); err != nil {
			return nil, err
		}
		logger.Info("Found application credentials for task", logger.Fields{
			loggerfield.TaskARN:       apiTask.Arn,
			loggerfield.TaskVersion:   apiTask.Version,
			loggerfield.RoleARN:       taskIAMRoleCredentials.RoleArn,
			loggerfield.RoleType:      taskIAMRoleCredentials.RoleType,
			loggerfield.CredentialsID: taskIAMRoleCredentials.CredentialsID,
		})
		apiTask.SetCredentialsID(taskIAMRoleCredentials.CredentialsID)
	}

	// Add ENI information to the task struct. Any invalid ENI invalidates the
	// whole task; it must not be added to the task engine.
	for _, acsENI := range task.ElasticNetworkInterfaces {
		eni, err := ni.InterfaceFromACS(acsENI)
		if err != nil {
			return nil, err
		}
		apiTask.AddTaskENI(eni)
	}

	// Add the App Mesh information to the task struct.
	if task.ProxyConfiguration != nil {
		appmesh, err := nlappmesh.AppMeshFromACS(task.ProxyConfiguration)
		if err != nil {
			return nil, err
		}
		apiTask.SetAppMesh(appmesh)
	}

	if task.ExecutionRoleCredentials != nil {
		// The payload contains execution credentials for the task. Add them to
		// the credentials manager and set the task's execution credentials id.
		taskExecutionIAMRoleCredentials := credentials.IAMRoleCredentialsFromACS(task.ExecutionRoleCredentials,
			credentials.ExecutionRoleType)
		if err := pmHandler.credentialsManager.SetTaskCredentials(&credentials.TaskIAMRoleCredentials{
			ARN:                aws.ToString(task.Arn),
			IAMRoleCredentials: taskExecutionIAMRoleCredentials,
		}); err != nil {
			return nil, err
		}
		logger.Info("Found execution credentials for task", logger.Fields{
			loggerfield.TaskARN:       apiTask.Arn,
			loggerfield.TaskVersion:   apiTask.Version,
			loggerfield.RoleARN:       taskExecutionIAMRoleCredentials.RoleArn,
			loggerfield.RoleType:      taskExecutionIAMRoleCredentials.RoleType,
			loggerfield.CredentialsID: taskExecutionIAMRoleCredentials.CredentialsID,
		})
		apiTask.SetExecutionRoleCredentialsID(taskExecutionIAMRoleCredentials.CredentialsID)
	}

	return apiTask, nil
}
// handleInvalidTask reports a task that could not be processed. It logs the
// error and submits a STOPPED task state change, whose reason is an
// UnrecognizedTaskError, to the task handler. A task without an ARN cannot be
// reported, so it is only logged.
func (pmHandler *payloadMessageHandler) handleInvalidTask(task *ecsacs.Task, err error,
	payload *ecsacs.PayloadMessage) {
	taskARN := aws.ToString(task.Arn)
	messageID := aws.ToString(payload.MessageId)

	logger.Warn("Received unexpected ACS message", logger.Fields{
		loggerfield.MessageID: messageID,
		loggerfield.TaskARN:   taskARN,
		loggerfield.Error:     err,
	})
	if taskARN == "" {
		logger.Critical("Received task with no ARN for payload message", logger.Fields{
			loggerfield.MessageID: messageID,
		})
		return
	}

	// Stopping the task is sufficient: it brings down the containers too.
	// The real task cannot be extracted from the payload message, but the task
	// handler will not send an event whose Task is nil, so an empty task is
	// attached instead.
	pmHandler.taskHandler.AddStateChangeEvent(api.TaskStateChange{
		TaskARN: taskARN,
		Status:  apitaskstatus.TaskStopped,
		Reason:  UnrecognizedTaskError{err}.Error(),
		Task:    &apitask.Task{},
	}, pmHandler.ecsClient)
}
// addTasks adds to the task engine every task in tasks for which
// skipAddTask(desired status) returns false. addPayloadTasks calls it twice so
// that tasks whose desired status is STOPPED are added before all other tasks.
//
// Tasks whose desired status is RUNNING are also saved through the data client.
// For each added task, a credentials ack request is generated for its task IAM
// role credentials and for its execution role credentials, whichever are set.
// It returns those ack requests and a bool that is false if saving any task or
// generating any credentials ack failed.
func (pmHandler *payloadMessageHandler) addTasks(payload *ecsacs.PayloadMessage, tasks []*apitask.Task,
	skipAddTask skipAddTaskComparatorFunc) ([]*ecsacs.IAMRoleCredentialsAckRequest, bool) {
	allTasksOK := true
	var credentialsAcks []*ecsacs.IAMRoleCredentialsAckRequest
	for _, task := range tasks {
		if skipAddTask(task.GetDesiredStatus()) {
			continue
		}
		pmHandler.taskEngine.AddTask(task)

		// Only need to save the task to the DB when its desired status is RUNNING
		// (i.e. this is a new task that we are going to manage). When its desired
		// status is STOPPED, the task is already in the DB and the desired status
		// change will be saved by the task manager.
		if task.GetDesiredStatus() == apitaskstatus.TaskRunning {
			if err := pmHandler.dataClient.SaveTask(task); err != nil {
				logger.Error("Failed to save data for task", logger.Fields{
					loggerfield.TaskARN: task.Arn,
					loggerfield.Error:   err,
				})
				allTasksOK = false
			}
		}

		// Generate an ack request for each credentials ID the task carries:
		// the task IAM role first, then the execution role.
		credentialsToAck := []struct {
			id          string
			description string
		}{
			{id: task.GetCredentialsID(), description: "task iam role"},
			{id: task.GetExecutionCredentialsID(), description: "task execution role"},
		}
		for _, creds := range credentialsToAck {
			if creds.id == "" {
				continue
			}
			ack, err := pmHandler.ackCredentials(payload.MessageId, creds.id)
			if err != nil {
				allTasksOK = false
				logger.Error(fmt.Sprintf("Failed to acknowledge %s credentials for task",
					creds.description), logger.Fields{
					"task":            task.String(),
					loggerfield.Error: err,
				})
				continue
			}
			credentialsAcks = append(credentialsAcks, ack)
		}
	}
	return credentialsAcks, allTasksOK
}
// ackCredentials builds the IAMRoleCredentialsAckRequest for the credentials
// identified by credentialsID, tagged with messageID and carrying the stored
// credentials' expiration and ID. It returns an error if the credentials
// manager has no credentials for credentialsID.
func (pmHandler *payloadMessageHandler) ackCredentials(messageID *string, credentialsID string) (
	*ecsacs.IAMRoleCredentialsAckRequest, error) {
	creds, ok := pmHandler.credentialsManager.GetTaskCredentials(credentialsID)
	if !ok {
		return nil, errors.Errorf("credentials could not be retrieved")
	}
	return &ecsacs.IAMRoleCredentialsAckRequest{
		MessageId:     messageID,
		Expiration:    aws.String(creds.IAMRoleCredentials.Expiration),
		CredentialsId: aws.String(creds.IAMRoleCredentials.CredentialsID),
	}, nil
}
// isTaskStatusStopped reports whether status is STOPPED. Used as a
// skipAddTaskComparatorFunc, it makes addTasks skip stopped tasks.
func isTaskStatusStopped(status apitaskstatus.TaskStatus) bool {
	switch status {
	case apitaskstatus.TaskStopped:
		return true
	default:
		return false
	}
}
// isTaskStatusNotStopped reports whether status is anything other than STOPPED.
// Used as a skipAddTaskComparatorFunc, it makes addTasks add only stopped tasks.
func isTaskStatusNotStopped(status apitaskstatus.TaskStatus) bool {
	stopped := status == apitaskstatus.TaskStopped
	return !stopped
}
// hasEBSAttachment reports whether acsTask carries an EBS task attachment and,
// if so, returns the volume name stored in that attachment's property named
// apiresource.VolumeNameKey.
//
// This runs on the raw ACS payload before apitask.TaskFromACS has validated
// it. Nil attachments, nil properties and properties without a value are
// therefore skipped instead of being dereferenced, which would panic.
//
// TODO: This only works if there's one EBS volume per task. If there is a case
// where a task has multiple EBS attachments, this needs to be modified.
func hasEBSAttachment(acsTask *ecsacs.Task) (string, bool) {
	for _, attachment := range acsTask.Attachments {
		if attachment == nil || aws.ToString(attachment.AttachmentType) != apiresource.EBSTaskAttach {
			continue
		}
		for _, property := range attachment.AttachmentProperties {
			if property == nil || property.Value == nil {
				continue
			}
			if aws.ToString(property.Name) == apiresource.VolumeNameKey {
				return *property.Value, true
			}
		}
	}
	return "", false
}
// initializeAttachmentTypeVolume sets the type of every volume in acsTask that
// is named volName and has no type yet to "attachment". Such a volume is later
// replaced by the EBS volume configuration parsed from the task attachments.
// Nil volumes and volumes without a name are skipped instead of being
// dereferenced, since the payload has not been validated at this point.
func initializeAttachmentTypeVolume(acsTask *ecsacs.Task, volName string) {
	for _, volume := range acsTask.Volumes {
		if volume == nil || volume.Name == nil || volume.Type != nil {
			continue
		}
		if *volume.Name == volName {
			volume.Type = aws.String("attachment")
		}
	}
}