fast-build-update-tool/internal/tools/ssh_enabler.go (162 lines of code) (raw):
package tools
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"regexp"
"strings"
"time"
"github.com/aws/amazon-gamelift-toolkit/fast-build-update-tool/internal/config"
"github.com/aws/amazon-gamelift-toolkit/fast-build-update-tool/internal/gamelift"
"golang.org/x/crypto/ssh"
)
// SSHEnabler enables and configures SSH on a remote GameLift instance.
//
// SSH is enabled over an AWS SSM session. SSH is used alongside SSM so that files
// can be uploaded to the instance. The default SSM configuration for GameLift does
// not support the SSH proxy flow.
type SSHEnabler struct {
	// logger is the enabler's logger (NewSSHEnabler tags it with context=SSHEnabler).
	logger *slog.Logger

	// instance is the target instance; its fleet ID, instance ID, region and
	// operating system are read when enabling SSH.
	instance *gamelift.Instance
	// instanceAccessGetter fetches the credentials used to start the SSM session.
	instanceAccessGetter GameLiftInstanceAccessGetter

	// clientPublicKey is the local public key in authorized_keys form, without
	// a trailing newline.
	clientPublicKey string

	// isNewCommandOutput reports whether a chunk of session output shows a new
	// shell prompt.
	isNewCommandOutput func(output string) bool
	// commandsToRun are the OS-specific commands sent to the session, one per prompt.
	commandsToRun []string
	// pty runs the SSM session process.
	pty PTY
}
// NewSSHEnabler builds a new SSHEnabler for the target instance.
//
// The SSH enable commands and prompt detector are chosen from
// instance.OperatingSystem. sshPort is passed to the Windows commands and is not
// used for Linux. An error is returned for an unknown operating system, or when
// Validate fails (the AWS CLI or the Session Manager plugin cannot be found).
//
// The pseudo-terminal is created only after those checks pass, so nothing is
// allocated for it when they fail. On any error the returned *SSHEnabler is nil.
func NewSSHEnabler(logger *slog.Logger, instance *gamelift.Instance, instanceAccessGetter GameLiftInstanceAccessGetter, localPublicKey ssh.PublicKey, sshPort int32) (*SSHEnabler, error) {
	localPublicKeyStr := convertPublicKeyToString(localPublicKey)

	var updateCommands []string
	var isNewCommandOutput func(output string) bool
	switch instance.OperatingSystem {
	case config.OperatingSystemWindows:
		updateCommands = windowsSSHEnableCommands(localPublicKeyStr, sshPort)
		isNewCommandOutput = IsNewCommandOutputWindows
	case config.OperatingSystemLinux:
		updateCommands = linuxSSHEnableCommands(localPublicKeyStr)
		isNewCommandOutput = IsNewCommandOutputLinux
	default:
		return nil, config.UnknownOperatingSystemError(fmt.Sprint(instance.OperatingSystem))
	}

	enabler := &SSHEnabler{
		logger:               logger.With("context", "SSHEnabler"),
		instance:             instance,
		instanceAccessGetter: instanceAccessGetter,
		clientPublicKey:      localPublicKeyStr,
		isNewCommandOutput:   isNewCommandOutput,
		commandsToRun:        updateCommands,
	}

	// Validate only checks for local executables and does not need the pty, so run
	// it before allocating one.
	if err := enabler.Validate(); err != nil {
		return nil, err
	}

	pty, err := newPtyCommandRunner()
	if err != nil {
		return nil, err
	}
	enabler.pty = pty

	return enabler, nil
}
// convertPublicKeyToString renders key as a single authorized_keys line
// (as produced by ssh.MarshalAuthorizedKey) with the trailing newline removed.
func convertPublicKeyToString(key ssh.PublicKey) string {
	line := ssh.MarshalAuthorizedKey(key)
	line = bytes.TrimSuffix(line, []byte("\n"))
	return string(line)
}
// Validate checks, in order, that the AWS CLI (awsCommand) and the Session Manager
// plugin (sessionManagerCommand) are available, using verifyExe. It returns the
// first error reported, or nil when both are found.
func (s *SSHEnabler) Validate() error {
	required := []string{awsCommand, sessionManagerCommand}
	for _, exe := range required {
		if err := verifyExe(exe); err != nil {
			return err
		}
	}
	return nil
}
// Enable enable SSH on the remote instance
func (s *SSHEnabler) Enable(ctx context.Context) (ssh.PublicKey, error) {
defer s.pty.Cleanup()
// Get remote instance access credentials
accessCredentials, err := s.instanceAccessGetter.GetInstanceAccess(ctx, s.instance.FleetId, s.instance.InstanceId)
if err != nil {
return nil, err
}
// Add AWS access credential environment variables
env := os.Environ()
env = append(env, envVar("AWS_REGION", s.instance.Region))
env = append(env, envVar("AWS_ACCESS_KEY_ID", accessCredentials.AccessKeyId))
env = append(env, envVar("AWS_SECRET_ACCESS_KEY", accessCredentials.SecretAccessKey))
env = append(env, envVar("AWS_SESSION_TOKEN", accessCredentials.SessionToken))
err = s.pty.Start("aws", []string{"ssm", "start-session", "--target", s.instance.InstanceId}, env)
if err != nil {
return nil, err
}
// Channel used to let us know when we can send the next command to the SSM session
commandReady := make(chan int, 1)
// Channel used to let us know when the SSM session has written out the contents of the remote server's public key
sshKeyReady := make(chan string, 1)
// Set up an io.Writer to handle the remote output of the SSM session
ioWriter := &ptyWriter{
logger: s.logger,
commandReady: commandReady,
commandsToAccept: len(s.commandsToRun),
sshKeyReady: sshKeyReady,
isNewCommandOutput: s.isNewCommandOutput,
clientPublicKey: s.clientPublicKey,
}
// Start a goroutine to actually send the commands to the remote session
go func() {
for i, command := range s.commandsToRun {
s.logger.Debug("waiting to run ssh enable command", "commandNumber", i)
<-commandReady
s.logger.Debug("running ssh enable command", "commandNumber", i)
err := s.pty.RunCommand(command)
if err != nil {
s.logger.Error("error running remote command", "error", err)
return
}
time.Sleep(200 * time.Millisecond)
}
}()
// Start a goroutine to copy the output from the SSM session to our writer
go func() {
_, err := io.Copy(ioWriter, s.pty.Reader())
if err != nil && !errors.Is(err, os.ErrClosed) && errors.Is(err, io.EOF) {
s.logger.Warn("error copying pty commands from remote instance", "err", err)
}
}()
// Wait for the SSM session to finish
err = s.pty.Wait()
if err != nil {
return nil, err
}
// Read the remote public SSH key out of the channel, and parse it
sshKey := <-sshKeyReady
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshKey))
if err != nil {
return nil, fmt.Errorf("error parsing remote public key %s %w", sshKey, err)
}
return pubKey, nil
}
// ptyWriter is the io.Writer that receives the output of the remote SSM session.
// It counts shell prompts so the next SSH enable command can be sent, and it picks
// the remote server's ed25519 public key out of the output.
type ptyWriter struct {
	logger *slog.Logger
	// commandReady receives the running prompt count (1, 2, ...) each time a new
	// prompt is seen, up to commandsToAccept values. It is closed after the last one.
	commandReady chan int
	// sshKeyReady receives the first server public key found in the output.
	sshKeyReady chan string
	// commandsToAccept is the number of prompts to signal on commandReady.
	commandsToAccept int
	// commandsRun is the number of prompts signalled so far.
	commandsRun int
	// isNewCommandOutput reports whether a chunk of output contains a new prompt.
	isNewCommandOutput func(output string) bool
	// clientPublicKey is the local key. Output matching it is not treated as the
	// server key.
	clientPublicKey string
}

// Write inspects one chunk of session output. It always reports all of p as
// consumed, with a nil error, so io.Copy keeps draining the session.
//
// NOTE(review): prompts and keys are only recognised within a single chunk. A key
// split across two writes would be matched in truncated form. Confirm that the pty
// reader delivers whole lines before relying on this.
func (w *ptyWriter) Write(p []byte) (int, error) {
	output := string(p)

	// A new prompt (eg. `sh $`) means the remote shell can accept the next command.
	if w.isNewCommandOutput(output) && w.commandsRun < w.commandsToAccept {
		w.commandsRun++
		w.commandReady <- w.commandsRun
		if w.commandsRun >= w.commandsToAccept {
			// Every command has been signalled; nothing more will be sent.
			close(w.commandReady)
		}
	}

	// One of the commands cats the server's public key file. Capture the key, since
	// it is needed to connect to the server later, but ignore echoes of the client's
	// own key.
	if match := FindED25519PublicKey(output); match != "" && !strings.Contains(w.clientPublicKey, match) {
		w.logger.Debug("found server public key", "key", match)
		select {
		case w.sshKeyReady <- match:
		default:
			// A key was already delivered and is still unread, or no one will read
			// another. A blocking send would hang Write and stall io.Copy, so the
			// session output would stop being drained.
			w.logger.Debug("server public key already captured, ignoring", "key", match)
		}
	}

	return len(p), nil
}

// publicKeyRegex matches "ssh-ed25519 " followed by the base64 key body.
var publicKeyRegex = regexp.MustCompile(`ssh-ed25519 ([A-Za-z0-9+/=]+)`)

// FindED25519PublicKey returns the first "ssh-ed25519 <base64>" key found in s,
// without any trailing comment, or "" if s contains none.
func FindED25519PublicKey(s string) string {
	return publicKeyRegex.FindString(s)
}
// envVar formats a single KEY=value entry for a process environment list.
func envVar(key, value string) string {
	return key + "=" + value
}
// awsCommand is the AWS CLI executable used to start SSM sessions.
const awsCommand = "aws"

// sessionManagerCommand is the AWS Session Manager plugin executable, which must be
// on the PATH alongside the AWS CLI.
const sessionManagerCommand = "session-manager-plugin"
// GameLiftInstanceAccessGetter retrieves access credentials for an instance in a
// GameLift fleet. SSHEnabler exports the returned credentials as AWS_* environment
// variables for the `aws ssm start-session` process.
//
//go:generate moq -skip-ensure -out ./moq_gamelift_instance_access_getter_test.go . GameLiftInstanceAccessGetter
type GameLiftInstanceAccessGetter interface {
// GetInstanceAccess returns access credentials for instanceId in fleetId.
GetInstanceAccess(ctx context.Context, fleetId string, instanceId string) (*gamelift.InstanceAccessCredentials, error)
}