in fast-build-update-tool/internal/tools/ssh_enabler.go [91:167]
// Enable opens an AWS SSM session to the configured instance, feeds it the
// commands in s.commandsToRun one at a time, and returns the remote server's
// public SSH key once the session has finished.
//
// The instance access credentials are fetched through s.instanceAccessGetter
// and passed to the "aws" CLI via environment variables. The pty is always
// cleaned up before Enable returns.
//
// An error is returned if the credentials cannot be fetched, the session
// cannot be started or exits with an error, ctx is cancelled while waiting for
// the remote key, or the key cannot be parsed.
func (s *SSHEnabler) Enable(ctx context.Context) (ssh.PublicKey, error) {
	defer s.pty.Cleanup()

	// Get remote instance access credentials
	accessCredentials, err := s.instanceAccessGetter.GetInstanceAccess(ctx, s.instance.FleetId, s.instance.InstanceId)
	if err != nil {
		return nil, err
	}

	// Add AWS access credential environment variables
	env := append(os.Environ(),
		envVar("AWS_REGION", s.instance.Region),
		envVar("AWS_ACCESS_KEY_ID", accessCredentials.AccessKeyId),
		envVar("AWS_SECRET_ACCESS_KEY", accessCredentials.SecretAccessKey),
		envVar("AWS_SESSION_TOKEN", accessCredentials.SessionToken),
	)

	err = s.pty.Start("aws", []string{"ssm", "start-session", "--target", s.instance.InstanceId}, env)
	if err != nil {
		return nil, err
	}

	// Closed when Enable returns so the command-sender goroutine cannot be
	// left blocked forever waiting for a prompt that will never arrive.
	done := make(chan struct{})
	defer close(done)

	// Channel used to let us know when we can send the next command to the SSM session
	commandReady := make(chan int, 1)
	// Channel used to let us know when the SSM session has written out the contents of the remote server's public key
	sshKeyReady := make(chan string, 1)

	// Set up an io.Writer to handle the remote output of the SSM session
	ioWriter := &ptyWriter{
		logger:             s.logger,
		commandReady:       commandReady,
		commandsToAccept:   len(s.commandsToRun),
		sshKeyReady:        sshKeyReady,
		isNewCommandOutput: s.isNewCommandOutput,
		clientPublicKey:    s.clientPublicKey,
	}

	// Start a goroutine to actually send the commands to the remote session
	go func() {
		for i, command := range s.commandsToRun {
			s.logger.Debug("waiting to run ssh enable command", "commandNumber", i)
			select {
			case <-commandReady:
			case <-done:
				// Enable has already returned; stop instead of leaking.
				return
			}
			s.logger.Debug("running ssh enable command", "commandNumber", i)
			if err := s.pty.RunCommand(command); err != nil {
				s.logger.Error("error running remote command", "error", err)
				return
			}
			// Brief pause between commands, kept from the original implementation.
			time.Sleep(200 * time.Millisecond)
		}
	}()

	// Start a goroutine to copy the output from the SSM session to our writer
	go func() {
		_, err := io.Copy(ioWriter, s.pty.Reader())
		// A closed pty or EOF is the normal end of the session; anything else
		// is worth surfacing.
		if err != nil && !errors.Is(err, os.ErrClosed) && !errors.Is(err, io.EOF) {
			s.logger.Warn("error copying pty commands from remote instance", "err", err)
		}
	}()

	// Wait for the SSM session to finish
	err = s.pty.Wait()
	if err != nil {
		return nil, err
	}

	// Read the remote public SSH key out of the channel, giving up if the
	// caller's context is cancelled before the key shows up.
	var sshKey string
	select {
	case sshKey = <-sshKeyReady:
	case <-ctx.Done():
		return nil, fmt.Errorf("waiting for remote public key: %w", ctx.Err())
	}

	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshKey))
	if err != nil {
		return nil, fmt.Errorf("error parsing remote public key %s %w", sshKey, err)
	}
	return pubKey, nil
}