fast-build-update-tool/internal/tools/ssh_command_runner.go (81 lines of code) (raw):
package tools
import (
"context"
"fmt"
"log/slog"
"net"
"os"
"path/filepath"
"github.com/aws/amazon-gamelift-toolkit/fast-build-update-tool/internal/config"
"github.com/aws/amazon-gamelift-toolkit/fast-build-update-tool/internal/gamelift"
"golang.org/x/crypto/ssh"
)
// SSHCommandRunner executes a shell command on a single remote instance over SSH.
type SSHCommandRunner struct {
	// Which instance to reach, and on which address and port.
	instanceId        string
	instanceIpAddress string
	sshPort           int32

	// Credentials presented to the remote SSH server.
	remoteUserName string
	sshKey         ssh.Signer

	// The command line executed on the remote instance.
	updateScriptCommand string

	// Structured logger for this runner.
	logger *slog.Logger
}
// NewSSHCommandRunner builds an SSHCommandRunner that will run the given update
// script on the given instance.
//
// The remote command line is derived from localUpdateScriptPath and the
// instance's operating system; if it cannot be generated, that error is
// returned unchanged and no runner is created. The remote user name is also
// chosen from the instance's operating system.
func NewSSHCommandRunner(logger *slog.Logger, localUpdateScriptPath string, sshPort int32, sshKey ssh.Signer, instance *gamelift.Instance) (*SSHCommandRunner, error) {
	// Resolve the remote command first: it is the only step that can fail.
	command, err := generateUpdateScriptCommand(localUpdateScriptPath, instance)
	if err != nil {
		return nil, err
	}

	runner := &SSHCommandRunner{
		logger:              logger.With("context", "SSHCommandRunner"),
		sshPort:             sshPort,
		instanceIpAddress:   instance.IpAddress,
		instanceId:          instance.InstanceId,
		sshKey:              sshKey,
		updateScriptCommand: command,
		remoteUserName:      string(config.RemoteUserForOperatingSystem(instance.OperatingSystem)),
	}
	return runner, nil
}
// Run opens an SSH connection to the remote instance and runs the update
// script command on it.
//
// The connection is only accepted if the remote host presents
// remotePublicKey. The command's standard output is written to a per-instance
// log file (truncated at the start of each run), and its standard error goes
// to the writer returned by config.NewErrorLogger.
//
// Cancelling ctx aborts connection establishment and closes an established
// connection, so Run returns instead of blocking on the network.
//
// It returns an error if the connection, session, or log file cannot be set
// up, or if running the remote command fails.
func (s *SSHCommandRunner) Run(ctx context.Context, remotePublicKey ssh.PublicKey) error {
	// Prefer the runner's own logger; fall back to the default one so a runner
	// built without a logger still works.
	logger := s.logger
	if logger == nil {
		logger = slog.Default()
	}

	address := net.JoinHostPort(s.instanceIpAddress, fmt.Sprintf("%d", s.sshPort))

	// Dial with the context so that cancellation interrupts the TCP connect.
	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "tcp", address)
	if err != nil {
		return fmt.Errorf("error dialing ssh connection: %w", err)
	}

	// The SSH handshake and the remote command do not take a context, so close
	// the underlying connection on cancellation to unblock them.
	stopWatching := context.AfterFunc(ctx, func() { _ = conn.Close() })
	defer stopWatching()

	// Set up the SSH connection to the remote instance
	sshConn, channels, requests, err := ssh.NewClientConn(conn, address, &ssh.ClientConfig{
		User:              s.remoteUserName,
		HostKeyCallback:   ssh.FixedHostKey(remotePublicKey),
		HostKeyAlgorithms: []string{remotePublicKey.Type()},
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(s.sshKey),
		},
	})
	if err != nil {
		// Make sure the socket is released; closing it twice is harmless.
		_ = conn.Close()
		return fmt.Errorf("error dialing ssh connection: %w", err)
	}
	client := ssh.NewClient(sshConn, channels, requests)
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return fmt.Errorf("error starting ssh session: %w", err)
	}
	defer session.Close()

	logFilePath := config.GetLogPathForFile(fmt.Sprintf("%s-ssh-command.log", s.instanceId))

	// Set up a log file so we log out any remote output we get from the instance.
	// Truncate it so output from an earlier, longer run is not left behind after
	// this run's output.
	logFile, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		return fmt.Errorf("error creating log file for ssh command runner: %w", err)
	}
	defer logFile.Close()

	session.Stdout = logFile
	session.Stderr = config.NewErrorLogger("SSHCommandRunner")

	logger.Debug("running command on instance", "command", s.updateScriptCommand)

	// Run the actual update command on the instance
	if err := session.Run(s.updateScriptCommand); err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			// Surface the cancellation so callers can detect it with errors.Is.
			return fmt.Errorf("error running server update script: %w (%w); Check logs in %s for more information", err, ctxErr, logFilePath)
		}
		return fmt.Errorf("error running server update script: %w; Check logs in %s for more information", err, logFilePath)
	}
	return nil
}
// generateUpdateScriptCommand builds the command line that runs the uploaded
// update script on the given instance.
//
// The remote script path is the operating system's remote upload directory
// followed directly by the base name of localUpdateScriptPath; no separator is
// inserted between them. Windows instances run the script through PowerShell,
// Linux instances mark it executable and then execute it. Any other operating
// system yields an empty command and an unknown-operating-system error.
func generateUpdateScriptCommand(localUpdateScriptPath string, instance *gamelift.Instance) (string, error) {
	targetOS := instance.OperatingSystem
	scriptPath := string(config.RemoteUploadDirectoryForOperatingSystem(targetOS)) + filepath.Base(localUpdateScriptPath)

	if targetOS == config.OperatingSystemWindows {
		return fmt.Sprintf("powershell.exe -ExecutionPolicy Bypass -File %s", scriptPath), nil
	}
	if targetOS == config.OperatingSystemLinux {
		return fmt.Sprintf("chmod +x %s && %s", scriptPath, scriptPath), nil
	}
	return "", config.UnknownOperatingSystemError(fmt.Sprint(targetOS))
}