func()

in fast-build-update-tool/internal/tools/ssh_enabler.go [91:167]


// Enable opens an AWS SSM session to s.instance, feeds it s.commandsToRun, and
// returns the remote host's public SSH key as reported by that session.
//
// The session is started by running `aws ssm start-session --target <instance>`
// in s.pty. It runs with the current process environment plus the region and
// temporary AWS credentials obtained from s.instanceAccessGetter. Each command
// in s.commandsToRun is sent only after the output writer signals on
// commandReady that the session will accept the next one. Once the session
// process exits, the key text delivered on sshKeyReady is parsed in
// authorized_keys format.
//
// An error is returned when:
//   - credentials cannot be obtained;
//   - the session fails to start or exits with an error;
//   - the session output ends without a key being reported;
//   - ctx is done while waiting for the key;
//   - the key cannot be parsed.
//
// s.pty is cleaned up before returning in all cases.
//
// NOTE(review): ctx is not observed while s.pty.Wait blocks on the session
// process; cancellation only takes effect once the process has exited.
func (s *SSHEnabler) Enable(ctx context.Context) (ssh.PublicKey, error) {
	defer s.pty.Cleanup()

	// Get remote instance access credentials
	accessCredentials, err := s.instanceAccessGetter.GetInstanceAccess(ctx, s.instance.FleetId, s.instance.InstanceId)
	if err != nil {
		return nil, fmt.Errorf("getting access credentials for instance %s: %w", s.instance.InstanceId, err)
	}

	// Add AWS access credential environment variables
	env := os.Environ()
	env = append(env, envVar("AWS_REGION", s.instance.Region))
	env = append(env, envVar("AWS_ACCESS_KEY_ID", accessCredentials.AccessKeyId))
	env = append(env, envVar("AWS_SECRET_ACCESS_KEY", accessCredentials.SecretAccessKey))
	env = append(env, envVar("AWS_SESSION_TOKEN", accessCredentials.SessionToken))

	err = s.pty.Start("aws", []string{"ssm", "start-session", "--target", s.instance.InstanceId}, env)
	if err != nil {
		return nil, fmt.Errorf("starting ssm session to instance %s: %w", s.instance.InstanceId, err)
	}

	// Closed when Enable returns so background goroutines do not outlive it.
	// Deferred after Cleanup, so (defers being LIFO) it is closed before
	// the pty is cleaned up.
	done := make(chan struct{})
	defer close(done)

	// Channel used to let us know when we can send the next command to the SSM session
	commandReady := make(chan int, 1)

	// Channel used to let us know when the SSM session has written out the contents of the remote server's public key
	sshKeyReady := make(chan string, 1)

	// Set up an io.Writer to handle the remote output of the SSM session
	ioWriter := &ptyWriter{
		logger:             s.logger,
		commandReady:       commandReady,
		commandsToAccept:   len(s.commandsToRun),
		sshKeyReady:        sshKeyReady,
		isNewCommandOutput: s.isNewCommandOutput,
		clientPublicKey:    s.clientPublicKey,
	}

	// Send the commands to the remote session, one per readiness signal.
	// Every wait also observes done. Otherwise the goroutine would block
	// forever if the session ends before all commands have been accepted.
	go func() {
		for i, command := range s.commandsToRun {
			s.logger.Debug("waiting to run ssh enable command", "commandNumber", i)
			select {
			case <-commandReady:
			case <-done:
				return
			}
			s.logger.Debug("running ssh enable command", "commandNumber", i)

			if err := s.pty.RunCommand(command); err != nil {
				s.logger.Error("error running remote command", "error", err)
				return
			}

			// Short pause between commands, abandoned if Enable returns.
			select {
			case <-time.After(200 * time.Millisecond):
			case <-done:
				return
			}
		}
	}()

	// Copy the output from the SSM session to our writer. ioWriter is only
	// written to from this goroutine, so once copyDone is closed no further
	// key can arrive on sshKeyReady.
	copyDone := make(chan struct{})
	go func() {
		defer close(copyDone)
		_, err := io.Copy(ioWriter, s.pty.Reader())
		// io.Copy reports a clean end of input as nil, so anything else
		// (other than the reader having been closed) is worth surfacing.
		if err != nil && !errors.Is(err, os.ErrClosed) && !errors.Is(err, io.EOF) {
			s.logger.Warn("error copying pty commands from remote instance", "err", err)
		}
	}()

	// Wait for the SSM session to finish
	if err := s.pty.Wait(); err != nil {
		return nil, fmt.Errorf("waiting for ssm session to instance %s: %w", s.instance.InstanceId, err)
	}

	// Read the remote public SSH key out of the channel without blocking
	// forever if the session never reported one.
	var sshKey string
	select {
	case sshKey = <-sshKeyReady:
	case <-copyDone:
		// The key may already be buffered, sent just before output ended.
		select {
		case sshKey = <-sshKeyReady:
		default:
			return nil, errors.New("ssm session output ended without reporting the remote ssh public key")
		}
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshKey))
	if err != nil {
		return nil, fmt.Errorf("parsing remote public key %q: %w", sshKey, err)
	}

	return pubKey, nil
}