helpers/ssh/ssh_command.go (151 lines of code) (raw):
package ssh
import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"gitlab.com/gitlab-org/gitlab-runner/common"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
type Client struct {
common.SshConfig
ConnectRetries int
client *ssh.Client
}
type Command struct {
Command string
Stdin string
Stdout io.Writer
Stderr io.Writer
}
type ExitError struct {
Inner error
}
func (e *ExitError) Error() string {
if e.Inner == nil {
return "error"
}
return e.Inner.Error()
}
func (e *ExitError) ExitCode() int {
var cryptoExitError *ssh.ExitError
if errors.As(e.Inner, &cryptoExitError) {
return cryptoExitError.ExitStatus()
}
return 0
}
func (s *Client) getSSHKey(identityFile string) (key ssh.Signer, err error) {
buf, err := os.ReadFile(identityFile)
if err != nil {
return nil, err
}
key, err = ssh.ParsePrivateKey(buf)
return key, err
}
func (s *Client) getSSHAuthMethods() ([]ssh.AuthMethod, error) {
var methods []ssh.AuthMethod
methods = append(methods, ssh.Password(s.Password))
if s.IdentityFile != "" {
key, err := s.getSSHKey(s.IdentityFile)
if err != nil {
return nil, err
}
methods = append(methods, ssh.PublicKeys(key))
}
return methods, nil
}
func getHostKeyCallback(config common.SshConfig) (ssh.HostKeyCallback, error) {
if config.ShouldDisableStrictHostKeyChecking() {
return ssh.InsecureIgnoreHostKey(), nil
}
if config.KnownHostsFile == "" {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("user home directory: %w", err)
}
config.KnownHostsFile = filepath.Join(homeDir, ".ssh", "known_hosts")
}
return knownhosts.New(config.KnownHostsFile)
}
func (s *Client) Connect() error {
if s.Host == "" {
s.Host = "localhost"
}
if s.User == "" {
s.User = "root"
}
if s.Port == "" {
s.Port = "22"
}
methods, err := s.getSSHAuthMethods()
if err != nil {
return fmt.Errorf("getting SSH authentication methods: %w", err)
}
config := &ssh.ClientConfig{
User: s.User,
Auth: methods,
}
hostKeyCallback, err := getHostKeyCallback(s.SshConfig)
if err != nil {
return fmt.Errorf("getting host key callback: %w", err)
}
config.HostKeyCallback = hostKeyCallback
connectRetries := s.ConnectRetries
if connectRetries == 0 {
connectRetries = 3
}
var finalError error
for i := 0; i < connectRetries; i++ {
client, err := ssh.Dial("tcp", s.Host+":"+s.Port, config)
if err == nil {
s.client = client
return nil
}
time.Sleep(sshRetryInterval * time.Second)
finalError = fmt.Errorf("ssh Dial() error: %w", err)
}
return finalError
}
func (s *Client) Run(ctx context.Context, cmd Command) error {
if s.client == nil {
return errors.New("not connected")
}
session, err := s.client.NewSession()
if err != nil {
return err
}
defer func() { _ = session.Close() }()
session.Stdin = strings.NewReader(cmd.Stdin)
session.Stdout = cmd.Stdout
session.Stderr = cmd.Stderr
err = session.Start(cmd.Command)
if err != nil {
return err
}
waitCh := make(chan error)
go func() {
err := session.Wait()
if _, ok := err.(*ssh.ExitError); ok {
err = &ExitError{Inner: err}
}
waitCh <- err
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGKILL)
_ = session.Close()
return <-waitCh
case err := <-waitCh:
return err
}
}
func (s *Client) Cleanup() {
if s.client != nil {
_ = s.client.Close()
}
}