internal/pkg/docker/orchestrator/orchestrator.go
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package orchestrator
import (
"context"
_ "embed"
"errors"
"fmt"
"io"
"maps"
"net"
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/aws/copilot-cli/internal/pkg/docker/dockerengine"
"github.com/aws/copilot-cli/internal/pkg/graph"
"github.com/aws/copilot-cli/internal/pkg/term/color"
"github.com/aws/copilot-cli/internal/pkg/term/log"
"golang.org/x/sync/errgroup"
)
// Orchestrator manages running a Task. Only a single Task
// can be running at a time for a given Orchestrator.
type Orchestrator struct {
idPrefix string
logOptions logOptionsFunc
curTask Task
curTaskID atomic.Int32
runErrs chan error
stopped chan struct{}
wg *sync.WaitGroup
actions chan action
stopOnce *sync.Once
docker DockerEngine
}
type action interface {
Do(o *Orchestrator) error
}
type logOptionsFunc func(name string, ctr ContainerDefinition) dockerengine.RunLogOptions
// DockerEngine is used by Orchestrator to manage containers.
type DockerEngine interface {
Run(context.Context, *dockerengine.RunOptions) error
IsContainerRunning(context.Context, string) (bool, error)
ContainerExitCode(ctx context.Context, containerName string) (int, error)
IsContainerHealthy(ctx context.Context, containerName string) (bool, error)
Stop(context.Context, string) error
Build(ctx context.Context, args *dockerengine.BuildArguments, w io.Writer) error
Exec(ctx context.Context, container string, out io.Writer, cmd string, args ...string) error
Rm(context.Context, string) error
}
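// Sentinel task IDs: once stopped, the Orchestrator stores orchestratorStoppedTaskID
// so that late container errors are ignored; the pause container runs under pauseCtrTaskID.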
const (
orchestratorStoppedTaskID = -1
pauseCtrTaskID = 0
)
const (
pauseCtrURI = "aws-copilot-pause"
pauseCtrTag = "latest"
)
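// proxyPortStart is the first local port assigned to the SSM port forwarding
// connections created by setupProxyConnections.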
const (
proxyPortStart = uint16(50000)
)
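// Container states that an entry in ContainerDefinition.DependsOn can wait for.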
const (
ctrStateHealthy = "healthy"
ctrStateComplete = "complete"
ctrStateSuccess = "success"
ctrStateStart = "start"
)
//go:embed Pause-Dockerfile
var pauseDockerfile string
// New creates a new Orchestrator. idPrefix is a prefix used when
// naming containers that are run by the Orchestrator.
func New(docker DockerEngine, idPrefix string, logOptions logOptionsFunc) *Orchestrator {
return &Orchestrator{
idPrefix: idPrefix,
logOptions: logOptions,
stopped: make(chan struct{}),
docker: docker,
wg: &sync.WaitGroup{},
stopOnce: &sync.Once{},
actions: make(chan action),
runErrs: make(chan error),
}
}
// Start starts the Orchestrator. Start must be called before any other
// orchestrator functions. Errors from containers run by the Orchestrator
// or from Orchestrator actions are sent to the returned error channel.
// The returned error channel is closed after Stop() has finished stopping
// the Orchestrator. An Orchestrator should only be Started once.
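//
// A minimal usage sketch (engine, logOpts, task, and the "demo-" prefix are
// hypothetical placeholders, not values defined by this package):
//
//	o := New(engine, "demo-", logOpts)
//	errs := o.Start()
//	go func() {
//		for err := range errs {
//			fmt.Printf("orchestrator error: %v\n", err)
//		}
//	}()
//	o.RunTask(task)
//	// ... later, on shutdown:
//	o.Stop()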
func (o *Orchestrator) Start() <-chan error {
// close done when all goroutines created by Orchestrator have finished
done := make(chan struct{})
// buffered channel so that the orchestrator routine does not block and
// can always send the error from both runErrs and action.Do to errs.
errs := make(chan error, 1)
// orchestrator routine
o.wg.Add(1) // decremented by stopAction
go func() {
for {
select {
case action := <-o.actions:
if err := action.Do(o); err != nil {
errs <- err
}
case err := <-o.runErrs:
errs <- err
case <-done:
close(errs)
return
}
}
}()
go func() {
o.wg.Wait()
close(done)
}()
return errs
}
// RunTask stops the currently running task, if any, and starts task.
func (o *Orchestrator) RunTask(task Task, opts ...RunTaskOption) {
r := &runTaskAction{
task: task,
}
for _, opt := range opts {
opt(r)
}
// this guarantees the following:
// - if r is pulled by the Orchestrator, any errors
// returned by it are reported by the Orchestrator.
// - if Stop() is called _before_ the Orchestrator picks up this
// action, then this action is skipped.
select {
case <-o.stopped:
case o.actions <- r:
}
}
type runTaskAction struct {
task Task
// optional vars for proxy
hosts []Host
ssmTarget string
network *net.IPNet
}
// RunTaskOption adds optional data to RunTask.
type RunTaskOption func(*runTaskAction)
// Host represents a service reachable via the network.
type Host struct {
Name string
Port uint16
}
// RunTaskWithProxy returns a RunTaskOption that sets up a proxy connection to hosts.
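//
// A sketch of a proxied run (the SSM target, CIDR, and host below are made-up
// illustration values):
//
//	_, network, _ := net.ParseCIDR("172.20.0.0/16")
//	o.RunTask(task, RunTaskWithProxy("i-0123456789abcdef0", *network,
//		Host{Name: "db.internal", Port: 5432}))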
func RunTaskWithProxy(ssmTarget string, network net.IPNet, hosts ...Host) RunTaskOption {
return func(r *runTaskAction) {
r.ssmTarget = ssmTarget
r.hosts = hosts
r.network = &network
}
}
func (a *runTaskAction) Do(o *Orchestrator) error {
// we no longer care about errors from the old task
taskID := o.curTaskID.Add(1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
	// cancel ctx if Stop() is called before ctx finishes on its own.
o.wg.Add(1)
go func() {
defer o.wg.Done()
select {
case <-ctx.Done():
case <-o.stopped:
cancel()
}
}()
prevTask := o.curTask
o.curTask = a.task
if taskID == 1 {
if err := o.buildPauseContainer(ctx); err != nil {
return fmt.Errorf("build pause container: %w", err)
}
// start the pause container
opts := o.pauseRunOptions(a.task)
o.run(pauseCtrTaskID, opts, true, cancel)
if err := o.waitForContainerToStart(ctx, opts.ContainerName); err != nil {
return fmt.Errorf("wait for pause container to start: %w", err)
}
if len(a.hosts) > 0 {
if err := o.setupProxyConnections(ctx, opts.ContainerName, a); err != nil {
return fmt.Errorf("setup proxy connections: %w", err)
}
}
} else {
// ensure no pause container changes
prevOpts := o.pauseRunOptions(prevTask)
newOpts := o.pauseRunOptions(a.task)
if !maps.Equal(prevOpts.EnvVars, newOpts.EnvVars) ||
!maps.Equal(prevOpts.Secrets, newOpts.Secrets) ||
!maps.Equal(prevOpts.ContainerPorts, newOpts.ContainerPorts) {
return errors.New("new task requires recreating pause container")
}
if err := o.stopTask(ctx, prevTask); err != nil {
return fmt.Errorf("stop existing task: %w", err)
}
}
depGraph := buildDependencyGraph(a.task.Containers)
err := depGraph.UpwardTraversal(ctx, func(ctx context.Context, containerName string) error {
if len(a.task.Containers[containerName].DependsOn) > 0 {
if err := o.waitForContainerDependencies(ctx, containerName, a.task.Containers); err != nil {
return fmt.Errorf("wait for container %s dependencies: %w", containerName, err)
}
}
o.run(taskID, o.containerRunOptions(containerName, a.task.Containers[containerName]), a.task.Containers[containerName].IsEssential, cancel)
var errContainerExited *dockerengine.ErrContainerExited
if err := o.waitForContainerToStart(ctx, o.containerID(containerName)); err != nil && !errors.As(err, &errContainerExited) {
return fmt.Errorf("wait for container %s to start: %w", containerName, err)
}
return nil
})
if err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return fmt.Errorf("upward traversal: %w", err)
}
return nil
}
func buildDependencyGraph(containers map[string]ContainerDefinition) *graph.LabeledGraph[string] {
var vertices []string
for vertex := range containers {
vertices = append(vertices, vertex)
}
dependencyGraph := graph.NewLabeledGraph(vertices)
for containerName, container := range containers {
for depCtr := range container.DependsOn {
dependencyGraph.Add(graph.Edge[string]{
From: containerName,
To: depCtr,
})
}
}
return dependencyGraph
}
// setupProxyConnections creates proxy connections to a.hosts in pauseContainer.
// It assumes that pauseContainer is already running. A unique proxy connection
// is created for each host (in parallel) using AWS SSM Port Forwarding through
// a.ssmTarget. Then, each connection is assigned an IP from a.network,
// starting at the bottom of the IP range. Using iptables, TCP packets destined
// for the connection's assigned IP are redirected to the connection. Finally,
// the host's name is mapped to its assigned IP in /etc/hosts.
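//
// For example, with two hosts the port-forwarding sessions listen on local
// ports 50000 and 50001 (counting up from proxyPortStart), and the hosts are
// mapped in /etc/hosts to consecutive addresses beginning with the first
// address of a.network.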
func (o *Orchestrator) setupProxyConnections(ctx context.Context, pauseContainer string, a *runTaskAction) error {
fmt.Printf("\nSetting up proxy connections...\n")
ports := make(map[Host]uint16)
port := proxyPortStart
for i := range a.hosts {
ports[a.hosts[i]] = port
port++
}
for _, host := range a.hosts {
host := host
portForHost := ports[host]
o.wg.Add(1)
go func() {
defer o.wg.Done()
err := o.docker.Exec(context.Background(), pauseContainer, io.Discard, "aws", "ssm", "start-session",
"--target", a.ssmTarget,
"--document-name", "AWS-StartPortForwardingSessionToRemoteHost",
"--parameters", fmt.Sprintf(`{"host":["%s"],"portNumber":["%d"],"localPortNumber":["%d"]}`, host.Name, host.Port, portForHost))
if err != nil {
// report err as a runtime error from the pause container
if o.curTaskID.Load() != orchestratorStoppedTaskID {
o.runErrs <- fmt.Errorf("proxy to %v:%v: %w", host.Name, host.Port, err)
}
}
}()
}
ip := a.network.IP
for host, port := range ports {
err := o.docker.Exec(ctx, pauseContainer, io.Discard, "iptables",
"--table", "nat",
"--append", "OUTPUT",
"--destination", ip.String(),
"--protocol", "tcp",
"--match", "tcp",
"--dport", strconv.Itoa(int(host.Port)),
"--jump", "REDIRECT",
"--to-ports", strconv.Itoa(int(port)))
if err != nil {
return fmt.Errorf("modify iptables: %w", err)
}
err = o.docker.Exec(ctx, pauseContainer, io.Discard, "/bin/bash",
"-c", fmt.Sprintf(`echo %s %s >> /etc/hosts`, ip.String(), host.Name))
if err != nil {
return fmt.Errorf("update /etc/hosts: %w", err)
}
ip, err = ipv4Increment(ip, a.network)
if err != nil {
return fmt.Errorf("increment ip: %w", err)
}
fmt.Printf("Created connection to %v:%v\n", host.Name, host.Port)
}
fmt.Printf("Finished setting up proxy connections\n\n")
return nil
}
// ipv4Increment returns a copy of ip incremented by one address. It returns
// an error if the increment overflows or the new address falls outside network.
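// For example, incrementing 10.0.0.255 within 10.0.0.0/16 yields 10.0.1.0,
// while incrementing 10.0.255.255 within 10.0.0.0/16 returns an error because
// 10.1.0.0 falls outside the network.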
func ipv4Increment(ip net.IP, network *net.IPNet) (net.IP, error) {
// make a copy of the previous ip
cpy := make(net.IP, len(ip))
copy(cpy, ip)
ipv4 := cpy.To4()
var inc func(idx int) error
inc = func(idx int) error {
if idx == -1 {
return errors.New("max ipv4 address")
}
ipv4[idx]++
		if ipv4[idx] == 0 { // overflow occurred
return inc(idx - 1)
}
return nil
}
err := inc(len(ipv4) - 1)
if err != nil {
return nil, err
}
if !network.Contains(ipv4) {
		return nil, errors.New("no more addresses in network")
}
return ipv4, nil
}
func (o *Orchestrator) buildPauseContainer(ctx context.Context) error {
arch := "64bit"
if strings.Contains(runtime.GOARCH, "arm") {
arch = "arm64"
}
return o.docker.Build(ctx, &dockerengine.BuildArguments{
URI: pauseCtrURI,
Tags: []string{pauseCtrTag},
DockerfileContent: pauseDockerfile,
Args: map[string]string{
"ARCH": arch,
},
}, os.Stderr)
}
// Stop stops the current running task containers and the Orchestrator. Stop is
// idempotent and safe to call multiple times. Calls to RunTask() after calling Stop
// do nothing.
func (o *Orchestrator) Stop() {
o.stopOnce.Do(func() {
close(o.stopped)
o.actions <- &stopAction{}
})
}
type stopAction struct{}
func (a *stopAction) Do(o *Orchestrator) error {
defer o.wg.Done() // for the Orchestrator
o.curTaskID.Store(orchestratorStoppedTaskID) // ignore runtime errors
fmt.Printf("\nStopping task...\n")
// collect errors since we want to try to clean up everything we can
var errs []error
if err := o.stopTask(context.Background(), o.curTask); err != nil {
errs = append(errs, err)
}
// stop pause container
fmt.Printf("Stopping and removing %q\n", "pause")
if err := o.docker.Stop(context.Background(), o.containerID("pause")); err != nil {
errs = append(errs, fmt.Errorf("stop %q: %w", "pause", err))
}
if err := o.docker.Rm(context.Background(), o.containerID("pause")); err != nil {
errs = append(errs, fmt.Errorf("remove %q: %w", "pause", err))
}
fmt.Printf("Stopped and removed %q\n", "pause")
return errors.Join(errs...)
}
// stopTask calls `docker stop` for all containers defined by task.
func (o *Orchestrator) stopTask(ctx context.Context, task Task) error {
if len(task.Containers) == 0 {
return nil
}
// errCh gets one error per container
errCh := make(chan error, len(task.Containers))
depGraph := buildDependencyGraph(task.Containers)
err := depGraph.DownwardTraversal(ctx, func(ctx context.Context, name string) error {
fmt.Printf("Stopping and removing %q\n", name)
if err := o.docker.Stop(ctx, o.containerID(name)); err != nil {
errCh <- fmt.Errorf("stop %q: %w", name, err)
return nil
}
if err := o.docker.Rm(ctx, o.containerID(name)); err != nil {
errCh <- fmt.Errorf("remove %q: %w", name, err)
return nil
}
fmt.Printf("Stopped and removed %q\n", name)
errCh <- nil
return nil
})
if err != nil {
return fmt.Errorf("downward traversal: %w", err)
}
var errs []error
for err := range errCh {
errs = append(errs, err)
if len(errs) == len(task.Containers) {
break
}
}
return errors.Join(errs...)
}
// waitForContainerToStart blocks until the container specified by id starts.
func (o *Orchestrator) waitForContainerToStart(ctx context.Context, id string) error {
for {
isRunning, err := o.docker.IsContainerRunning(ctx, id)
switch {
case err != nil:
return fmt.Errorf("check if %q is running: %w", id, err)
case isRunning:
log.Successf("Successfully started container %s\n", id)
return nil
}
select {
case <-time.After(1 * time.Second):
case <-ctx.Done():
return ctx.Err()
}
}
}
func (o *Orchestrator) waitForContainerDependencies(ctx context.Context, name string, definitions map[string]ContainerDefinition) error {
var deps []string
for depName, state := range definitions[name].DependsOn {
deps = append(deps, fmt.Sprintf("%s->%s", depName, state))
}
logMsg := strings.Join(deps, ", ")
fmt.Printf("Waiting for container %q dependencies: [%s]\n", name, color.Emphasize(logMsg))
eg, ctx := errgroup.WithContext(ctx)
	for depName, depState := range definitions[name].DependsOn {
		depName, depState := depName, depState
eg.Go(func() error {
			ctrId := o.containerID(depName)
ticker := time.NewTicker(700 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-ctx.Done():
return ctx.Err()
}
				switch depState {
case ctrStateStart:
return nil
case ctrStateHealthy:
healthy, err := o.docker.IsContainerHealthy(ctx, ctrId)
if err != nil {
return fmt.Errorf("wait for container %q to be healthy: %w", ctrId, err)
}
if healthy {
log.Successf("Successfully dependency container %q reached healthy\n", ctrId)
return nil
}
case ctrStateComplete:
exitCode, err := o.docker.ContainerExitCode(ctx, ctrId)
var errContainerNotExited *dockerengine.ErrContainerNotExited
if errors.As(err, &errContainerNotExited) {
continue
}
if err != nil {
return fmt.Errorf("wait for container %q to complete: %w", ctrId, err)
}
log.Successf("%q's dependency container %q exited with code: %d\n", name, ctrId, exitCode)
return nil
case ctrStateSuccess:
exitCode, err := o.docker.ContainerExitCode(ctx, ctrId)
var errContainerNotExited *dockerengine.ErrContainerNotExited
if errors.As(err, &errContainerNotExited) {
continue
}
if err != nil {
return fmt.Errorf("wait for container %q to success: %w", ctrId, err)
}
if exitCode != 0 {
return fmt.Errorf("dependency container %q exited with non-zero exit code %d", ctrId, exitCode)
}
log.Successf("%q's dependency container %q exited with code: %d\n", name, ctrId, exitCode)
return nil
}
}
})
}
return eg.Wait()
}
// containerID returns the full ID for a container with name run by o.
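// For example, with the (hypothetical) prefix "demo-", containerID("pause")
// returns "demo-pause".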
func (o *Orchestrator) containerID(name string) string {
return o.idPrefix + name
}
// Task defines a set of Containers to be run together.
// Containers within a Task can talk to each other on localhost
// and are stopped and started as a group.
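//
// A hypothetical single-container Task might look like:
//
//	task := Task{
//		Containers: map[string]ContainerDefinition{
//			"web": {
//				ImageURI:    "example.com/web:latest",
//				EnvVars:     map[string]string{"PORT": "8080"},
//				Ports:       map[string]string{"8080": "8080"},
//				IsEssential: true,
//			},
//		},
//	}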
type Task struct {
Containers map[string]ContainerDefinition
PauseSecrets map[string]string
}
// ContainerDefinition defines information necessary to run a container.
type ContainerDefinition struct {
ImageURI string
EnvVars map[string]string
Secrets map[string]string
Ports map[string]string // host port -> container port
IsEssential bool
DependsOn map[string]string
}
// pauseRunOptions returns RunOptions for the pause container for t.
// The pause container owns the networking namespace that is shared
// among all of the containers in the task.
func (o *Orchestrator) pauseRunOptions(t Task) dockerengine.RunOptions {
opts := dockerengine.RunOptions{
ImageURI: fmt.Sprintf("%s:%s", pauseCtrURI, pauseCtrTag),
ContainerName: o.containerID("pause"),
Command: []string{"sleep", "infinity"},
ContainerPorts: make(map[string]string),
Secrets: t.PauseSecrets,
AddLinuxCapabilities: []string{"NET_ADMIN"},
Init: true,
}
for _, ctr := range t.Containers {
for hostPort, ctrPort := range ctr.Ports {
// TODO some error if host port is already defined?
opts.ContainerPorts[hostPort] = ctrPort
}
}
return opts
}
// containerRunOptions returns RunOptions for the given container.
func (o *Orchestrator) containerRunOptions(name string, ctr ContainerDefinition) dockerengine.RunOptions {
return dockerengine.RunOptions{
ImageURI: ctr.ImageURI,
ContainerName: o.containerID(name),
EnvVars: ctr.EnvVars,
Secrets: ctr.Secrets,
ContainerNetwork: o.containerID("pause"),
LogOptions: o.logOptions(name, ctr),
}
}
// run calls `docker run` using opts. Errors are only returned
// to the main Orchestrator routine if the taskID the container was run with
// matches the current taskID the Orchestrator is running.
func (o *Orchestrator) run(taskID int32, opts dockerengine.RunOptions, isEssential bool, cancel context.CancelFunc) {
o.wg.Add(1)
go func() {
defer o.wg.Done()
err := o.docker.Run(context.Background(), &opts)
// if the orchestrator has already stopped,
// we don't want to report the error
curTaskID := o.curTaskID.Load()
if curTaskID == orchestratorStoppedTaskID {
return
}
// the error is from the pause container
// or from the currently running task
if taskID == pauseCtrTaskID || taskID == curTaskID {
var errContainerExited *dockerengine.ErrContainerExited
if !isEssential && (errors.As(err, &errContainerExited) || err == nil) {
fmt.Printf("non-essential container %q stopped\n", opts.ContainerName)
return
}
if err == nil {
err = errors.New("container stopped unexpectedly")
}
			// cancel the context to signal the other goroutines spawned by `graph.UpwardTraversal` to stop.
cancel()
o.runErrs <- fmt.Errorf("run %q: %w", opts.ContainerName, err)
}
}()
}