// internal/pkg/cli/run_local.go
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
sdkecs "github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/rds"
sdksecretsmanager "github.com/aws/aws-sdk-go/service/secretsmanager"
sdkssm "github.com/aws/aws-sdk-go/service/ssm"
cmdtemplate "github.com/aws/copilot-cli/cmd/copilot/template"
"github.com/aws/copilot-cli/internal/pkg/aws/ecr"
awsecs "github.com/aws/copilot-cli/internal/pkg/aws/ecs"
"github.com/aws/copilot-cli/internal/pkg/aws/identity"
"github.com/aws/copilot-cli/internal/pkg/aws/resourcegroups"
"github.com/aws/copilot-cli/internal/pkg/aws/secretsmanager"
"github.com/aws/copilot-cli/internal/pkg/aws/sessions"
"github.com/aws/copilot-cli/internal/pkg/aws/ssm"
clideploy "github.com/aws/copilot-cli/internal/pkg/cli/deploy"
"github.com/aws/copilot-cli/internal/pkg/cli/file"
"github.com/aws/copilot-cli/internal/pkg/cli/group"
"github.com/aws/copilot-cli/internal/pkg/config"
"github.com/aws/copilot-cli/internal/pkg/deploy"
"github.com/aws/copilot-cli/internal/pkg/deploy/cloudformation"
"github.com/aws/copilot-cli/internal/pkg/describe"
"github.com/aws/copilot-cli/internal/pkg/docker/dockerengine"
"github.com/aws/copilot-cli/internal/pkg/docker/dockerfile"
"github.com/aws/copilot-cli/internal/pkg/docker/orchestrator"
"github.com/aws/copilot-cli/internal/pkg/ecs"
"github.com/aws/copilot-cli/internal/pkg/exec"
"github.com/aws/copilot-cli/internal/pkg/manifest"
"github.com/aws/copilot-cli/internal/pkg/repository"
"github.com/aws/copilot-cli/internal/pkg/template"
termcolor "github.com/aws/copilot-cli/internal/pkg/term/color"
"github.com/aws/copilot-cli/internal/pkg/term/log"
termprogress "github.com/aws/copilot-cli/internal/pkg/term/progress"
"github.com/aws/copilot-cli/internal/pkg/term/prompt"
"github.com/aws/copilot-cli/internal/pkg/term/selector"
"github.com/aws/copilot-cli/internal/pkg/term/syncbuffer"
"github.com/aws/copilot-cli/internal/pkg/workspace"
"github.com/fsnotify/fsnotify"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
const (
workloadAskPrompt = "Which workload would you like to run locally?"
)
const (
// Command to retrieve container credentials with ecs exec. See more at https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html.
// Example output: {"AccessKeyId":"ACCESS_KEY_ID","Expiration":"EXPIRATION_DATE","RoleArn":"TASK_ROLE_ARN","SecretAccessKey":"SECRET_ACCESS_KEY","Token":"SECURITY_TOKEN_STRING"}
curlContainerCredentialsCmd = "curl 169.254.170.2$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
)
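// containerOrchestrator manages the lifecycle of a task's containers run locally.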
type containerOrchestrator interface {
Start() <-chan error
RunTask(orchestrator.Task, ...orchestrator.RunTaskOption)
Stop()
}
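// hostFinder discovers hosts that a locally run task may need to proxy connections to.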
type hostFinder interface {
Hosts(context.Context) ([]orchestrator.Host, error)
}
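// taggedResourceGetter returns AWS resources of a given type that match a set of tags.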
type taggedResourceGetter interface {
GetResourcesByTags(string, map[string]string) ([]*resourcegroups.Resource, error)
}
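// rdsDescriber pages through descriptions of RDS instances and clusters.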
type rdsDescriber interface {
DescribeDBInstancesPagesWithContext(context.Context, *rds.DescribeDBInstancesInput, func(*rds.DescribeDBInstancesOutput, bool) bool, ...request.Option) error
DescribeDBClustersPagesWithContext(context.Context, *rds.DescribeDBClustersInput, func(*rds.DescribeDBClustersOutput, bool) bool, ...request.Option) error
}
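// recursiveWatcher watches a directory tree for file system events.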
type recursiveWatcher interface {
Add(path string) error
Close() error
Events() <-chan fsnotify.Event
Errors() <-chan error
}
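// runLocalVars holds the flag values for the "run local" command.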
type runLocalVars struct {
wkldName string
wkldType string
appName string
envName string
envOverrides map[string]string
watch bool
useTaskRole bool
portOverrides portOverrides
proxy bool
proxyNetwork net.IPNet
}
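// runLocalOpts holds the clients, sessions, and helpers used to run a workload locally.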
type runLocalOpts struct {
runLocalVars
sel deploySelector
ecsClient ecsClient
ecsExecutor ecsCommandExecutor
ssm secretGetter
secretsManager secretGetter
sessProvider sessionProvider
sess *session.Session
envManagerSess *session.Session
targetEnv *config.Environment
targetApp *config.Application
store store
ws wsWlDirReader
cmd execRunner
dockerEngine dockerEngineRunner
repository repositoryService
prog progress
orchestrator containerOrchestrator
hostFinder hostFinder
envChecker versionCompatibilityChecker
debounceTime time.Duration
dockerExcludes []string
newRecursiveWatcher func() (recursiveWatcher, error)
buildContainerImages func(mft manifest.DynamicWorkload) (map[string]string, error)
configureClients func() error
labeledTermPrinter func(fw syncbuffer.FileWriter, bufs []*syncbuffer.LabeledSyncBuffer, opts ...syncbuffer.LabeledTermPrinterOption) clideploy.LabeledTermPrinter
unmarshal func([]byte) (manifest.DynamicWorkload, error)
newInterpolator func(app, env string) interpolator
captureStdout func() (io.Reader, error)
releaseStdout func()
}
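// newRunLocalOpts returns a runLocalOpts initialized with default sessions, stores, and lazily configured clients.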
func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) {
sessProvider := sessions.ImmutableProvider(sessions.UserAgentExtras("run local"))
defaultSess, err := sessProvider.Default()
if err != nil {
return nil, err
}
store := config.NewSSMStore(identity.New(defaultSess), sdkssm.New(defaultSess), aws.StringValue(defaultSess.Config.Region))
deployStore, err := deploy.NewStore(sessProvider, store)
if err != nil {
return nil, err
}
ws, err := workspace.Use(afero.NewOsFs())
if err != nil {
return nil, err
}
labeledTermPrinter := func(fw syncbuffer.FileWriter, bufs []*syncbuffer.LabeledSyncBuffer, opts ...syncbuffer.LabeledTermPrinterOption) clideploy.LabeledTermPrinter {
return syncbuffer.NewLabeledTermPrinter(fw, bufs, opts...)
}
o := &runLocalOpts{
runLocalVars: vars,
sel: selector.NewDeploySelect(prompt.New(), store, deployStore),
store: store,
ws: ws,
newInterpolator: newManifestInterpolator,
sessProvider: sessProvider,
unmarshal: manifest.UnmarshalWorkload,
sess: defaultSess,
cmd: exec.NewCmd(),
dockerEngine: dockerengine.New(exec.NewCmd()),
labeledTermPrinter: labeledTermPrinter,
prog: termprogress.NewSpinner(log.DiagnosticWriter),
}
o.configureClients = func() error {
defaultSessEnvRegion, err := o.sessProvider.DefaultWithRegion(o.targetEnv.Region)
if err != nil {
return fmt.Errorf("create default session with region %s: %w", o.targetEnv.Region, err)
}
o.envManagerSess, err = o.sessProvider.FromRole(o.targetEnv.ManagerRoleARN, o.targetEnv.Region)
if err != nil {
return fmt.Errorf("create env manager session %s: %w", o.targetEnv.Region, err)
}
// EnvManagerRole has permissions to get task def and get SSM values.
// However, it doesn't have permissions to get secrets from secrets manager,
// so use the default sess and *hope* they have permissions.
o.ecsClient = ecs.New(o.envManagerSess)
o.ssm = ssm.New(o.envManagerSess)
o.ecsExecutor = awsecs.New(o.envManagerSess)
o.secretsManager = secretsmanager.New(defaultSessEnvRegion)
resources, err := cloudformation.New(o.sess, cloudformation.WithProgressTracker(os.Stderr)).GetAppResourcesByRegion(o.targetApp, o.targetEnv.Region)
if err != nil {
return fmt.Errorf("get application %s resources from region %s: %w", o.appName, o.envName, err)
}
repoName := clideploy.RepoName(o.appName, o.wkldName)
o.repository = repository.NewWithURI(ecr.New(defaultSessEnvRegion), repoName, resources.RepositoryURLs[o.wkldName])
idPrefix := fmt.Sprintf("%s-%s-%s-", o.appName, o.envName, o.wkldName)
colorGen := termcolor.ColorGenerator()
o.orchestrator = orchestrator.New(o.dockerEngine, idPrefix, func(name string, ctr orchestrator.ContainerDefinition) dockerengine.RunLogOptions {
return dockerengine.RunLogOptions{
Color: colorGen(),
Output: os.Stderr,
LinePrefix: fmt.Sprintf("[%s] ", name),
}
})
o.hostFinder = &hostDiscoverer{
app: o.appName,
env: o.envName,
wkld: o.wkldName,
ecs: ecs.New(o.envManagerSess),
rg: resourcegroups.New(o.envManagerSess),
rds: rds.New(o.envManagerSess),
}
envDesc, err := describe.NewEnvDescriber(describe.NewEnvDescriberConfig{
App: o.appName,
Env: o.envName,
ConfigStore: store,
})
if err != nil {
return fmt.Errorf("create env describer: %w", err)
}
o.envChecker = envDesc
return nil
}
o.buildContainerImages = func(mft manifest.DynamicWorkload) (map[string]string, error) {
if dockerWkld, ok := mft.Manifest().(dockerWorkload); ok {
dfDir := filepath.Dir(dockerWkld.Dockerfile())
o.dockerExcludes, err = dockerfile.ReadDockerignore(afero.NewOsFs(), filepath.Join(ws.Path(), dfDir))
if err != nil {
return nil, err
}
o.filterDockerExcludes()
}
gitShortCommit := imageTagFromGit(o.cmd)
image := clideploy.ContainerImageIdentifier{
GitShortCommitTag: gitShortCommit,
}
out := &clideploy.UploadArtifactsOutput{}
if err := clideploy.BuildContainerImages(&clideploy.ImageActionInput{
Name: o.wkldName,
WorkspacePath: o.ws.Path(),
Image: image,
Mft: mft.Manifest(),
GitShortCommitTag: gitShortCommit,
Builder: o.repository,
Login: o.repository.Login,
CheckDockerEngine: o.dockerEngine.CheckDockerEngineRunning,
LabeledTermPrinter: o.labeledTermPrinter,
}, out); err != nil {
return nil, err
}
containerURIs := make(map[string]string, len(out.ImageDigests))
for name, info := range out.ImageDigests {
if len(info.RepoTags) == 0 {
// this shouldn't happen, but guard against it to avoid an index-out-of-range panic
return nil, fmt.Errorf("no repo tags for image %q", name)
}
containerURIs[name] = info.RepoTags[0]
}
return containerURIs, nil
}
o.debounceTime = 5 * time.Second
o.newRecursiveWatcher = func() (recursiveWatcher, error) {
return file.NewRecursiveWatcher(0)
}
// Capture stdout by replacing it with a piped writer and returning an attached io.Reader.
// Functions are concurrency safe and idempotent.
var mu sync.Mutex
var savedWriter, savedStdout *os.File
savedStdout = os.Stdout
	o.captureStdout = func() (io.Reader, error) {
		mu.Lock()
		defer mu.Unlock()
		if savedWriter != nil {
			savedWriter.Close()
		}
		pipeReader, pipeWriter, err := os.Pipe()
		if err != nil {
			return nil, err
		}
		savedWriter = pipeWriter
		os.Stdout = savedWriter
		return pipeReader, nil
	}
o.releaseStdout = func() {
mu.Lock()
defer mu.Unlock()
os.Stdout = savedStdout
savedWriter.Close()
}
return o, nil
}
// Validate returns an error for any invalid optional flags.
func (o *runLocalOpts) Validate() error {
if o.appName == "" {
return errNoAppInWorkspace
}
	// Ensure that the provided application name exists in the config store.
app, err := o.store.GetApplication(o.appName)
if err != nil {
return fmt.Errorf("get application %s: %w", o.appName, err)
}
o.targetApp = app
return nil
}
// Ask prompts the user for any unprovided required fields and validates them.
func (o *runLocalOpts) Ask() error {
return o.validateAndAskWkldEnvName()
}
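// validateAndAskWkldEnvName validates any workload and environment names passed by flag
// and resolves the deployed workload to run, prompting if needed.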
func (o *runLocalOpts) validateAndAskWkldEnvName() error {
if o.envName != "" {
env, err := o.store.GetEnvironment(o.appName, o.envName)
if err != nil {
return err
}
o.targetEnv = env
}
if o.wkldName != "" {
if _, err := o.store.GetWorkload(o.appName, o.wkldName); err != nil {
return err
}
}
deployedWorkload, err := o.sel.DeployedWorkload(workloadAskPrompt, "", o.appName, selector.WithEnv(o.envName), selector.WithName(o.wkldName))
if err != nil {
return fmt.Errorf("select a deployed workload from application %s: %w", o.appName, err)
}
if o.envName == "" {
env, err := o.store.GetEnvironment(o.appName, deployedWorkload.Env)
if err != nil {
return fmt.Errorf("get environment %q configuration: %w", o.envName, err)
}
o.targetEnv = env
}
o.wkldName = deployedWorkload.Name
o.envName = deployedWorkload.Env
o.wkldType = deployedWorkload.Type
return nil
}
// Execute builds and runs the workload images locally.
func (o *runLocalOpts) Execute() error {
if err := o.configureClients(); err != nil {
return err
}
ctx := context.Background()
task, err := o.prepareTask(ctx)
if err != nil {
return err
}
var hosts []orchestrator.Host
var ssmTarget string
if o.proxy {
if err := validateMinEnvVersion(o.ws, o.envChecker, o.appName, o.envName, template.RunLocalProxyMinEnvVersion, "run local --proxy"); err != nil {
return err
}
hosts, err = o.hostFinder.Hosts(ctx)
if err != nil {
return fmt.Errorf("find hosts to connect to: %w", err)
}
ssmTarget, err = o.getSSMTarget(ctx)
if err != nil {
return fmt.Errorf("get proxy target container: %w", err)
}
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
errCh := o.orchestrator.Start()
var runTaskOpts []orchestrator.RunTaskOption
if o.proxy {
runTaskOpts = append(runTaskOpts, orchestrator.RunTaskWithProxy(ssmTarget, o.proxyNetwork, hosts...))
}
o.orchestrator.RunTask(task, runTaskOpts...)
var watchCh <-chan interface{}
var watchErrCh <-chan error
stopCh := make(chan struct{})
if o.watch {
watchCh, watchErrCh, err = o.watchLocalFiles(stopCh)
if err != nil {
return fmt.Errorf("setup watch: %s", err)
}
}
for {
select {
case err, ok := <-errCh:
// we loop until errCh closes, since Start()
// closes errCh when the orchestrator is completely done.
if !ok {
close(stopCh)
return nil
}
log.Errorf("error: %s\n", err)
o.orchestrator.Stop()
case <-sigCh:
signal.Stop(sigCh)
o.orchestrator.Stop()
		case err := <-watchErrCh:
			log.Errorf("watch: %s\n", err)
			o.orchestrator.Stop()
case <-watchCh:
task, err = o.prepareTask(ctx)
if err != nil {
log.Errorf("rerun task: %s\n", err)
o.orchestrator.Stop()
break
}
			// If TaskRole credentials were retrieved through ECS Exec, OS signals are no longer
			// delivered to this channel. Re-registering the channel here is a short-term fix that
			// lets interrupt and terminate signals stop tasks after --watch restarts the task.
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
o.orchestrator.RunTask(task)
}
}
}
// getSSMTarget returns an AWS SSM target for a running container
// that supports ECS Exec.
func (o *runLocalOpts) getSSMTarget(ctx context.Context) (string, error) {
svc, err := o.ecsClient.DescribeService(o.appName, o.envName, o.wkldName)
if err != nil {
return "", fmt.Errorf("describe service: %w", err)
}
for _, task := range svc.Tasks {
// TaskArn should have the format: arn:aws:ecs:us-west-2:123456789:task/clusterName/taskName
taskARN, err := arn.Parse(aws.StringValue(task.TaskArn))
if err != nil {
return "", fmt.Errorf("parse task arn: %w", err)
}
split := strings.Split(taskARN.Resource, "/")
if len(split) != 3 {
return "", fmt.Errorf("task ARN in unexpected format: %q", taskARN)
}
taskName := split[2]
for _, ctr := range task.Containers {
id := aws.StringValue(ctr.RuntimeId)
hasECSExec := slices.ContainsFunc(ctr.ManagedAgents, func(a *sdkecs.ManagedAgent) bool {
return aws.StringValue(a.Name) == "ExecuteCommandAgent" && aws.StringValue(a.LastStatus) == "RUNNING"
})
if id != "" && hasECSExec && aws.StringValue(ctr.LastStatus) == "RUNNING" {
return fmt.Sprintf("ecs:%s_%s_%s", svc.ClusterName, taskName, aws.StringValue(ctr.RuntimeId)), nil
}
}
}
return "", errors.New("no running tasks have running containers with ecs exec enabled")
}
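// getTask converts the workload's deployed task definition into an orchestrator.Task,
// resolving environment variables, secrets, and port mappings for each container.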
func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) {
td, err := o.ecsClient.TaskDefinition(o.appName, o.envName, o.wkldName)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get task definition: %w", err)
}
envVars, err := o.getEnvVars(ctx, td)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get env vars: %w", err)
}
if o.useTaskRole {
taskRoleCredsVars, err := o.taskRoleCredentials(ctx)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("retrieve task role credentials: %w", err)
}
		// overwrite any AWS credential environment variables with the TaskRole credentials
for ctr := range envVars {
for k, v := range taskRoleCredsVars {
envVars[ctr][k] = envVarValue{
Value: v,
Secret: true,
}
}
}
}
containerDeps := o.getContainerDependencies(td)
task := orchestrator.Task{
Containers: make(map[string]orchestrator.ContainerDefinition, len(td.ContainerDefinitions)),
}
if o.proxy {
pauseSecrets, err := sessionEnvVars(ctx, o.envManagerSess)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get pause container secrets: %w", err)
}
task.PauseSecrets = pauseSecrets
}
for _, ctr := range td.ContainerDefinitions {
name := aws.StringValue(ctr.Name)
def := orchestrator.ContainerDefinition{
ImageURI: aws.StringValue(ctr.Image),
EnvVars: envVars[name].EnvVars(),
Secrets: envVars[name].Secrets(),
Ports: make(map[string]string, len(ctr.PortMappings)),
IsEssential: containerDeps[name].isEssential,
DependsOn: containerDeps[name].dependsOn,
}
for _, port := range ctr.PortMappings {
hostPort := strconv.FormatInt(aws.Int64Value(port.HostPort), 10)
ctrPort := hostPort
if port.ContainerPort != nil {
ctrPort = strconv.FormatInt(aws.Int64Value(port.ContainerPort), 10)
}
for _, override := range o.portOverrides {
if override.container == ctrPort {
hostPort = override.host
break
}
}
def.Ports[hostPort] = ctrPort
}
task.Containers[name] = def
}
return task, nil
}
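// prepareTask builds the task to run locally: it fetches the task definition, builds
// container images from the workspace manifest, and applies manifest-level container dependencies.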
func (o *runLocalOpts) prepareTask(ctx context.Context) (orchestrator.Task, error) {
task, err := o.getTask(ctx)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get task: %w", err)
}
mft, _, err := workloadManifest(&workloadManifestInput{
name: o.wkldName,
appName: o.appName,
envName: o.envName,
ws: o.ws,
interpolator: o.newInterpolator(o.appName, o.envName),
unmarshal: o.unmarshal,
sess: o.envManagerSess,
})
if err != nil {
return orchestrator.Task{}, err
}
containerURIs, err := o.buildContainerImages(mft)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("build images: %w", err)
}
// replace built images with the local built URI
for name, uri := range containerURIs {
ctr, ok := task.Containers[name]
if !ok {
return orchestrator.Task{}, fmt.Errorf("built an image for %q, which doesn't exist in the task", name)
}
ctr.ImageURI = uri
task.Containers[name] = ctr
}
containerDeps := manifest.ContainerDependencies(mft.Manifest())
for name, dep := range containerDeps {
ctr, ok := task.Containers[name]
if !ok {
return orchestrator.Task{}, fmt.Errorf("missing container: %q is listed as a dependency, which doesn't exist in the task", name)
}
ctr.IsEssential = dep.IsEssential
ctr.DependsOn = dep.DependsOn
task.Containers[name] = ctr
}
return task, nil
}
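// filterDockerExcludes removes .dockerignore patterns that cover the copilot directory
// so changes to those files are always watched.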
func (o *runLocalOpts) filterDockerExcludes() {
wsPath := o.ws.Path()
result := []string{}
	// filter out excludes that target the copilot directory; we always want to watch those files
copilotDirPath := filepath.ToSlash(filepath.Join(wsPath, workspace.CopilotDirName))
for _, exclude := range o.dockerExcludes {
if !strings.HasPrefix(filepath.ToSlash(exclude), copilotDirPath) {
result = append(result, exclude)
}
}
o.dockerExcludes = result
}
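// watchLocalFiles watches the workspace for file changes and notifies the returned channel
// once changes have settled for debounceTime. Hidden paths and .dockerignore matches are skipped.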
func (o *runLocalOpts) watchLocalFiles(stopCh <-chan struct{}) (<-chan interface{}, <-chan error, error) {
workspacePath := o.ws.Path()
watchCh := make(chan interface{})
watchErrCh := make(chan error)
watcher, err := o.newRecursiveWatcher()
if err != nil {
return nil, nil, fmt.Errorf("file: %w", err)
}
if err = watcher.Add(workspacePath); err != nil {
return nil, nil, err
}
watcherEvents := watcher.Events()
watcherErrors := watcher.Errors()
debounceTimer := time.NewTimer(o.debounceTime)
debounceTimerRunning := false
if !debounceTimer.Stop() {
// flush the timer in case stop is called after the timer finishes
<-debounceTimer.C
}
go func() {
for {
select {
case <-stopCh:
watcher.Close()
return
			case err, ok := <-watcherErrors:
				if !ok {
					watcher.Close()
					return
				}
				watchErrCh <- err
case event, ok := <-watcherEvents:
if !ok {
watcher.Close()
return
}
// skip chmod events
if event.Has(fsnotify.Chmod) {
break
}
parent := workspacePath
suffix, _ := strings.CutPrefix(event.Name, parent+"/")
// check if any subdirectories within copilot directory are hidden
// fsnotify events are always of form /a/b/c, don't use filepath.Split as that's OS dependent
isHidden := false
for _, child := range strings.Split(suffix, "/") {
parent = filepath.Join(parent, child)
subdirHidden, err := file.IsHiddenFile(child)
if err != nil {
break
}
if subdirHidden {
isHidden = true
}
}
// skip updates from files matching .dockerignore patterns
isExcluded := false
for _, pattern := range o.dockerExcludes {
matches, err := filepath.Match(pattern, suffix)
if err != nil {
break
}
if matches {
isExcluded = true
}
}
if !isHidden && !isExcluded {
if !debounceTimerRunning {
fmt.Println("Restarting task...")
debounceTimerRunning = true
}
debounceTimer.Reset(o.debounceTime)
}
case <-debounceTimer.C:
debounceTimerRunning = false
watchCh <- nil
}
}
}()
return watchCh, watchErrCh, nil
}
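// sessionEnvVars returns the AWS credential and region environment variables derived from the given session.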
func sessionEnvVars(ctx context.Context, sess *session.Session) (map[string]string, error) {
creds, err := sess.Config.Credentials.GetWithContext(ctx)
if err != nil {
return nil, fmt.Errorf("get IAM credentials: %w", err)
}
env := map[string]string{
"AWS_ACCESS_KEY_ID": creds.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
"AWS_SESSION_TOKEN": creds.SessionToken,
}
if sess.Config.Region != nil {
env["AWS_DEFAULT_REGION"] = aws.StringValue(sess.Config.Region)
env["AWS_REGION"] = aws.StringValue(sess.Config.Region)
}
return env, nil
}
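// taskRoleCredentials returns environment variables holding the workload's TaskRole credentials,
// trying sts:AssumeRole first and falling back to ECS Exec.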
func (o *runLocalOpts) taskRoleCredentials(ctx context.Context) (map[string]string, error) {
// assumeRoleMethod tries to directly call sts:AssumeRole for TaskRole using default session
// calls sts:AssumeRole through aws-sdk-go here https://github.com/aws/aws-sdk-go/blob/ac58203a9054cc9d901429bdd94edfc0a7a1de46/aws/credentials/stscreds/assume_role_provider.go#L352
assumeRoleMethod := func() (map[string]string, error) {
taskDef, err := o.ecsClient.TaskDefinition(o.appName, o.envName, o.wkldName)
if err != nil {
return nil, err
}
taskRoleSess, err := o.sessProvider.FromRole(aws.StringValue(taskDef.TaskRoleArn), o.targetEnv.Region)
if err != nil {
return nil, err
}
return sessionEnvVars(ctx, taskRoleSess)
}
	// ecsExecMethod uses ECS Exec to retrieve credentials from a running container
ecsExecMethod := func() (map[string]string, error) {
svcDesc, err := o.ecsClient.DescribeService(o.appName, o.envName, o.wkldName)
if err != nil {
return nil, fmt.Errorf("describe ECS service for %s in environment %s: %w", o.wkldName, o.envName, err)
}
stdoutReader, err := o.captureStdout()
if err != nil {
return nil, err
}
defer o.releaseStdout()
// try exec on each container within the service
var wg sync.WaitGroup
containerErr := make(chan error)
for _, task := range svcDesc.Tasks {
taskID, err := awsecs.TaskID(aws.StringValue(task.TaskArn))
if err != nil {
return nil, err
}
for _, container := range task.Containers {
wg.Add(1)
containerName := aws.StringValue(container.Name)
go func() {
defer wg.Done()
err := o.ecsExecutor.ExecuteCommand(awsecs.ExecuteCommandInput{
Cluster: svcDesc.ClusterName,
Command: fmt.Sprintf("/bin/sh -c %q\n", curlContainerCredentialsCmd),
Task: taskID,
Container: containerName,
})
if err != nil {
containerErr <- fmt.Errorf("container %s in task %s: %w", containerName, taskID, err)
}
}()
}
}
// wait for containers to finish and reset stdout
containersFinished := make(chan struct{})
go func() {
wg.Wait()
o.releaseStdout()
close(containersFinished)
}()
type containerCredentialsOutput struct {
AccessKeyId string
SecretAccessKey string
Token string
}
// parse stdout to try and find credentials
credsResult := make(chan map[string]string)
parseErr := make(chan error)
go func() {
select {
case <-containersFinished:
buf, err := io.ReadAll(stdoutReader)
if err != nil {
parseErr <- err
return
}
lines := bytes.Split(buf, []byte("\n"))
var creds containerCredentialsOutput
for _, line := range lines {
err := json.Unmarshal(line, &creds)
if err != nil {
continue
}
credsResult <- map[string]string{
"AWS_ACCESS_KEY_ID": creds.AccessKeyId,
"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
"AWS_SESSION_TOKEN": creds.Token,
}
return
}
parseErr <- errors.New("all containers failed to retrieve credentials")
case <-ctx.Done():
return
}
}()
var containerErrs []error
for {
select {
case creds := <-credsResult:
return creds, nil
case <-ctx.Done():
return nil, ctx.Err()
case err := <-parseErr:
return nil, errors.Join(append([]error{err}, containerErrs...)...)
case err := <-containerErr:
containerErrs = append(containerErrs, err)
}
}
}
credentialsChain := []func() (map[string]string, error){
assumeRoleMethod,
ecsExecMethod,
}
credentialsChainWrappedErrs := []string{
"assume role",
"ecs exec",
}
// return TaskRole credentials from first successful method
var errs []error
for errIndex, method := range credentialsChain {
vars, err := method()
if err == nil {
return vars, nil
}
errs = append(errs, fmt.Errorf("%s: %w", credentialsChainWrappedErrs[errIndex], err))
}
return nil, &errTaskRoleRetrievalFailed{errs}
}
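// containerEnv maps environment variable names to values for a single container.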
type containerEnv map[string]envVarValue
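// envVarValue is an environment variable value, plus whether it is a secret
// and whether it was overridden via flag.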
type envVarValue struct {
Value string
Secret bool
Override bool
}
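// EnvVars returns the non-secret environment variables in c.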
func (c containerEnv) EnvVars() map[string]string {
if c == nil {
return nil
}
out := make(map[string]string)
for k, v := range c {
if !v.Secret {
out[k] = v.Value
}
}
return out
}
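// Secrets returns the secret environment variables in c.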
func (c containerEnv) Secrets() map[string]string {
if c == nil {
return nil
}
out := make(map[string]string)
for k, v := range c {
if v.Secret {
out[k] = v.Value
}
}
return out
}
// getEnvVars combines the environment overrides passed via flags with the environment
// variables and secrets specified in the task definition to build the set of environment
// variables for each container in the task definition. The returned map is keyed by
// container name; each value maps a variable name to an envVarValue, which records
// whether the variable is a secret.
func (o *runLocalOpts) getEnvVars(ctx context.Context, taskDef *awsecs.TaskDefinition) (map[string]containerEnv, error) {
envVars := make(map[string]containerEnv, len(taskDef.ContainerDefinitions))
for _, ctr := range taskDef.ContainerDefinitions {
envVars[aws.StringValue(ctr.Name)] = make(map[string]envVarValue)
}
for _, e := range taskDef.EnvironmentVariables() {
envVars[e.Container][e.Name] = envVarValue{
Value: e.Value,
}
}
if err := o.fillEnvOverrides(envVars); err != nil {
return nil, fmt.Errorf("parse env overrides: %w", err)
}
if err := o.fillSecrets(ctx, envVars, taskDef); err != nil {
return nil, fmt.Errorf("get secrets: %w", err)
}
	// inject session variables if they haven't already been set
sessionVars, err := sessionEnvVars(ctx, o.sess)
if err != nil {
return nil, err
}
for ctr := range envVars {
for k, v := range sessionVars {
if _, ok := envVars[ctr][k]; !ok {
envVars[ctr][k] = envVarValue{
Value: v,
Secret: true,
}
}
}
}
return envVars, nil
}
// fillEnvOverrides parses environment variable overrides passed via flag.
// The expected format of the flag values is KEY=VALUE, with an optional container name
// in the format of [containerName]:KEY=VALUE. If the container name is omitted,
// the environment variable override is applied to all containers in the task definition.
func (o *runLocalOpts) fillEnvOverrides(envVars map[string]containerEnv) error {
for k, v := range o.envOverrides {
if !strings.Contains(k, ":") {
// apply override to all containers
for ctr := range envVars {
envVars[ctr][k] = envVarValue{
Value: v,
Override: true,
}
}
continue
}
// only apply override to the specified container
split := strings.SplitN(k, ":", 2)
ctr, key := split[0], split[1] // len(split) will always be 2 since we know there is a ":"
if _, ok := envVars[ctr]; !ok {
return fmt.Errorf("%q targets invalid container", k)
}
envVars[ctr][key] = envVarValue{
Value: v,
Override: true,
}
}
return nil
}
// fillSecrets collects non-overridden secrets from the task definition and
// makes requests to SSM and Secrets Manager to get their value.
func (o *runLocalOpts) fillSecrets(ctx context.Context, envVars map[string]containerEnv, taskDef *awsecs.TaskDefinition) error {
	// figure out which secrets we need to fetch, temporarily storing ValueFrom as the value
unique := make(map[string]string)
for _, s := range taskDef.Secrets() {
cur, ok := envVars[s.Container][s.Name]
if cur.Override {
// ignore secrets that were overridden
continue
}
if ok {
return fmt.Errorf("secret names must be unique, but an environment variable %q already exists", s.Name)
}
envVars[s.Container][s.Name] = envVarValue{
Value: s.ValueFrom,
Secret: true,
}
unique[s.ValueFrom] = ""
}
// get value of all needed secrets
g, ctx := errgroup.WithContext(ctx)
mu := &sync.Mutex{}
mu.Lock() // lock until finished ranging over unique
for valueFrom := range unique {
valueFrom := valueFrom
g.Go(func() error {
val, err := o.getSecret(ctx, valueFrom)
if err != nil {
return fmt.Errorf("get secret %q: %w", valueFrom, err)
}
mu.Lock()
defer mu.Unlock()
unique[valueFrom] = val
return nil
})
}
mu.Unlock()
if err := g.Wait(); err != nil {
return err
}
// replace secrets with resolved values
for ctr, vars := range envVars {
for key, val := range vars {
if val.Secret {
envVars[ctr][key] = envVarValue{
Value: unique[val.Value],
Secret: true,
}
}
}
}
return nil
}
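// getSecret fetches a secret from SSM or Secrets Manager depending on the service in the ARN.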
func (o *runLocalOpts) getSecret(ctx context.Context, valueFrom string) (string, error) {
	// SSM secrets can be specified as a parameter name instead of an ARN.
	// Default to SSM if valueFrom is not an ARN.
getter := o.ssm
if parsed, err := arn.Parse(valueFrom); err == nil { // only overwrite if successful
switch parsed.Service {
case sdkssm.ServiceName:
getter = o.ssm
case sdksecretsmanager.ServiceName:
getter = o.secretsManager
default:
return "", fmt.Errorf("invalid ARN; not a SSM or Secrets Manager ARN")
}
}
return getter.GetSecretValue(ctx, valueFrom)
}
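// containerDependency records whether a container is essential and the containers it depends on.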
type containerDependency struct {
isEssential bool
dependsOn map[string]string
}
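// getContainerDependencies extracts each container's essential flag and dependencies from the task definition.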
func (o *runLocalOpts) getContainerDependencies(taskDef *awsecs.TaskDefinition) map[string]containerDependency {
dependencies := make(map[string]containerDependency, len(taskDef.ContainerDefinitions))
for _, ctr := range taskDef.ContainerDefinitions {
dep := containerDependency{
isEssential: aws.BoolValue(ctr.Essential),
dependsOn: make(map[string]string),
}
for _, containerDep := range ctr.DependsOn {
dep.dependsOn[aws.StringValue(containerDep.ContainerName)] = strings.ToLower(aws.StringValue(containerDep.Condition))
}
dependencies[aws.StringValue(ctr.Name)] = dep
}
return dependencies
}
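// hostDiscoverer discovers Service Connect and RDS endpoints for a workload's environment.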
type hostDiscoverer struct {
ecs ecsClient
app string
env string
wkld string
rg taggedResourceGetter
rds rdsDescriber
}
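// Hosts returns the Service Connect aliases and RDS endpoints that a local proxy should forward traffic to.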
func (h *hostDiscoverer) Hosts(ctx context.Context) ([]orchestrator.Host, error) {
svcs, err := h.ecs.ServiceConnectServices(h.app, h.env, h.wkld)
if err != nil {
return nil, fmt.Errorf("get service connect services: %w", err)
}
var hosts []orchestrator.Host
for _, svc := range svcs {
// find the primary deployment with Service Connect enabled
idx := slices.IndexFunc(svc.Deployments, func(dep *sdkecs.Deployment) bool {
return aws.StringValue(dep.Status) == "PRIMARY" && aws.BoolValue(dep.ServiceConnectConfiguration.Enabled)
})
if idx == -1 {
continue
}
for _, sc := range svc.Deployments[idx].ServiceConnectConfiguration.Services {
for _, alias := range sc.ClientAliases {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(alias.DnsName),
Port: uint16(aws.Int64Value(alias.Port)),
})
}
}
}
rdsHosts, err := h.rdsHosts(ctx)
if err != nil {
return nil, fmt.Errorf("get rds hosts: %w", err)
}
return append(hosts, rdsHosts...), nil
}
// rdsHosts returns RDS endpoints for databases tagged for this workload
// or for its environment, discovered via direct AWS SDK calls.
func (h *hostDiscoverer) rdsHosts(ctx context.Context) ([]orchestrator.Host, error) {
var hosts []orchestrator.Host
resources, err := h.rg.GetResourcesByTags(resourcegroups.ResourceTypeRDS, map[string]string{
deploy.AppTagKey: h.app,
deploy.EnvTagKey: h.env,
})
switch {
case err != nil:
return nil, fmt.Errorf("get tagged resources: %w", err)
case len(resources) == 0:
return nil, nil
}
dbFilter := &rds.Filter{
Name: aws.String("db-instance-id"),
}
clusterFilter := &rds.Filter{
Name: aws.String("db-cluster-id"),
}
for i := range resources {
		// we don't want resources that belong to other workloads,
		// but we do want env-level resources
if wkld, ok := resources[i].Tags[deploy.ServiceTagKey]; ok && wkld != h.wkld {
continue
}
arn, err := arn.Parse(resources[i].ARN)
if err != nil {
return nil, fmt.Errorf("invalid arn %q: %w", resources[i].ARN, err)
}
switch {
case strings.HasPrefix(arn.Resource, "db:"):
dbFilter.Values = append(dbFilter.Values, aws.String(resources[i].ARN))
case strings.HasPrefix(arn.Resource, "cluster:"):
clusterFilter.Values = append(clusterFilter.Values, aws.String(resources[i].ARN))
}
}
if len(dbFilter.Values) > 0 {
err = h.rds.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{
Filters: []*rds.Filter{dbFilter},
}, func(out *rds.DescribeDBInstancesOutput, lastPage bool) bool {
for _, db := range out.DBInstances {
if db.Endpoint != nil {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(db.Endpoint.Address),
Port: uint16(aws.Int64Value(db.Endpoint.Port)),
})
}
}
return true
})
if err != nil {
return nil, fmt.Errorf("describe instances: %w", err)
}
}
if len(clusterFilter.Values) > 0 {
err = h.rds.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
Filters: []*rds.Filter{clusterFilter},
}, func(out *rds.DescribeDBClustersOutput, lastPage bool) bool {
for _, db := range out.DBClusters {
add := func(s *string) {
if s != nil {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(s),
Port: uint16(aws.Int64Value(db.Port)),
})
}
}
add(db.Endpoint)
add(db.ReaderEndpoint)
for i := range db.CustomEndpoints {
add(db.CustomEndpoints[i])
}
}
return true
})
if err != nil {
return nil, fmt.Errorf("describe clusters: %w", err)
}
}
return hosts, nil
}
// BuildRunLocalCmd builds the command for running a workload locally.
func BuildRunLocalCmd() *cobra.Command {
vars := runLocalVars{}
cmd := &cobra.Command{
Use: "run local",
Short: "Run the workload locally.",
Long: "Run the workload locally.",
RunE: runCmdE(func(cmd *cobra.Command, args []string) error {
opts, err := newRunLocalOpts(vars)
if err != nil {
return err
}
return run(opts)
}),
Annotations: map[string]string{
"group": group.Develop,
},
}
cmd.SetUsageTemplate(cmdtemplate.Usage)
cmd.Flags().StringVarP(&vars.wkldName, nameFlag, nameFlagShort, "", workloadFlagDescription)
cmd.Flags().StringVarP(&vars.envName, envFlag, envFlagShort, "", envFlagDescription)
cmd.Flags().StringVarP(&vars.appName, appFlag, appFlagShort, tryReadingAppName(), appFlagDescription)
cmd.Flags().BoolVar(&vars.watch, watchFlag, false, watchFlagDescription)
cmd.Flags().BoolVar(&vars.useTaskRole, useTaskRoleFlag, false, useTaskRoleFlagDescription)
cmd.Flags().Var(&vars.portOverrides, portOverrideFlag, portOverridesFlagDescription)
cmd.Flags().StringToStringVar(&vars.envOverrides, envVarOverrideFlag, nil, envVarOverrideFlagDescription)
cmd.Flags().BoolVar(&vars.proxy, proxyFlag, false, proxyFlagDescription)
	cmd.Flags().IPNetVar(&vars.proxyNetwork, proxyNetworkFlag, net.IPNet{
		// docker uses 172.17.0.0/16 for networking by default,
		// so we default to a different /16 within the 172.16.0.0/12
		// private range defined by RFC 1918.
		IP:   net.IPv4(172, 20, 0, 0),
		Mask: net.CIDRMask(16, 32),
	}, proxyNetworkFlagDescription)
return cmd
}