pkg/testing/ssh/client.go (239 lines of code) (raw):
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License 2.0;
// you may not use this file except in compliance with the Elastic License 2.0.
package ssh
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
// logger is the logging sink used by sshClient to report connection and
// command failures. The callers in this file pass fmt-style format strings
// (for example with %q and %s verbs) together with their arguments.
type logger interface {
// Logf logs the message for this runner.
Logf(format string, args ...any)
}
// sshClient holds the connection parameters and the live connection used to
// talk to a single host over SSH.
type sshClient struct {
// ip is the host part of the address dialed by Connect (port 22 is appended there).
ip string
// username is the SSH user used when authenticating.
username string
// auth is the single authentication method offered to the server.
auth ssh.AuthMethod
// logger receives diagnostics about failed connects and commands.
logger logger
// c is the established SSH client; it is set by Connect and reset to nil by Close.
c *ssh.Client
}
// NewClient builds an SSH client for the given host, user and authentication
// method. No connection is opened here; the returned client only stores the
// parameters, and Connect has to be called before it can be used.
func NewClient(ip string, username string, sshAuth ssh.AuthMethod, logger logger) SSHClient {
	client := new(sshClient)
	client.ip = ip
	client.username = username
	client.auth = sshAuth
	client.logger = logger
	return client
}
// Connect connects to the host.
//
// It resolves the host address on port 22 once and then keeps trying to
// establish a TCP connection followed by an SSH handshake until one attempt
// succeeds or ctx is done. Failed attempts are logged and retried after a
// delay that starts at one second and doubles after every failure.
//
// When ctx ends before a connection is established, the error of the last
// failed attempt is returned, or ctx.Err() if no attempt was made. A failure
// to resolve the address is returned immediately without retrying.
func (s *sshClient) Connect(ctx context.Context) error {
	var lastErr error
	config := &ssh.ClientConfig{
		User:            s.username,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // it's the tests framework test
		Auth:            []ssh.AuthMethod{s.auth},
		Timeout:         30 * time.Second,
	}
	addr := net.JoinHostPort(s.ip, "22")

	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return fmt.Errorf("unable to resolve ssh address %q :%w", addr, err)
	}
	delay := 1 * time.Second
	for {
		if ctx.Err() != nil {
			if lastErr == nil {
				return ctx.Err()
			}
			return lastErr
		}
		if lastErr != nil {
			s.logger.Logf("ssh connect error: %q, will try again in %s", lastErr, delay)
			// Wait out the backoff, but give up as soon as ctx is done instead
			// of sleeping through the cancellation and dialing once more.
			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				timer.Stop()
				return lastErr
			case <-timer.C:
			}
			delay = 2 * delay
		}
		conn, err := net.DialTCP("tcp", nil, tcpAddr)
		if err != nil {
			lastErr = fmt.Errorf("error dialing tcp address %q :%w", addr, err)
			continue
		}
		err = conn.SetKeepAlive(true)
		if err != nil {
			_ = conn.Close()
			lastErr = fmt.Errorf("error setting TCP keepalive for ssh to %q :%w", addr, err)
			continue
		}
		// The keepalive period reuses the 30s value configured as the client timeout.
		err = conn.SetKeepAlivePeriod(config.Timeout)
		if err != nil {
			_ = conn.Close()
			lastErr = fmt.Errorf("error setting TCP keepalive period for ssh to %q :%w", addr, err)
			continue
		}

		sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
		if err != nil {
			_ = conn.Close()
			lastErr = fmt.Errorf("error NewClientConn for ssh to %q :%w", addr, err)
			continue
		}
		s.c = ssh.NewClient(sshConn, chans, reqs)
		return nil
	}
}
// ConnectWithTimeout behaves like Connect but bounds the whole connection
// attempt, retries included, by timeout. The deadline is applied through a
// context derived from ctx, so cancelling ctx still aborts the attempt.
func (s *sshClient) ConnectWithTimeout(ctx context.Context, timeout time.Duration) error {
	boundedCtx, release := context.WithTimeout(ctx, timeout)
	defer release()

	return s.Connect(boundedCtx)
}
// Close shuts down the underlying SSH connection, if there is one, and
// forgets it. It returns the error reported by the connection's Close, or
// nil when the client holds no connection.
func (s *sshClient) Close() error {
	if s.c == nil {
		// Nothing to close.
		return nil
	}

	closeErr := s.c.Close()
	s.c = nil
	return closeErr
}
// Reconnect disconnects from the host and then connects to it again.
func (s *sshClient) Reconnect(ctx context.Context) error {
	// The outcome of closing the old connection is deliberately discarded:
	// a new connection is attempted regardless.
	_ = s.Close()

	connectErr := s.Connect(ctx)
	return connectErr
}
// ReconnectWithTimeout disconnects from the host and reconnects to it, giving
// up once timeout has elapsed. The deadline is applied through a context
// derived from ctx.
func (s *sshClient) ReconnectWithTimeout(ctx context.Context, timeout time.Duration) error {
	boundedCtx, release := context.WithTimeout(ctx, timeout)
	defer release()

	return s.Reconnect(boundedCtx)
}
// NewSession opens a new Session for this host.
//
// It returns an error, rather than panicking on a nil connection, when the
// client is not connected: before Connect has succeeded, after Close, or
// after a Reconnect whose connect step failed.
func (s *sshClient) NewSession() (*ssh.Session, error) {
	if s.c == nil {
		return nil, fmt.Errorf("ssh client is not connected to %q", s.ip)
	}
	return s.c.NewSession()
}
// Exec runs a command on the host and returns what it wrote to stdout and
// stderr.
//
// cmd and args are joined with single spaces into the command line that is
// sent to the host. When stdin is non-nil it is attached as the command's
// standard input. If a session cannot be opened, the client reconnects once
// and tries again before giving up.
//
// ctx is checked before any work starts and is passed to the reconnect; it
// does not interrupt a command that is already running. The captured stdout
// and stderr are returned even when the command itself fails.
func (s *sshClient) Exec(ctx context.Context, cmd string, args []string, stdin io.Reader) ([]byte, []byte, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, nil, ctxErr
	}

	// Assemble the full command line: the command followed by its arguments.
	parts := make([]string, 0, 1+len(args))
	parts = append(parts, cmd)
	parts = append(parts, args...)
	command := strings.Join(parts, " ")

	sess, err := s.NewSession()
	if err != nil {
		s.logger.Logf("new session failed: %q, trying reconnect", err)
		if reconnectErr := s.Reconnect(ctx); reconnectErr != nil {
			return nil, nil, fmt.Errorf("ssh reconnect failed: %w, after new session failed: %w", reconnectErr, err)
		}
		var retryErr error
		sess, retryErr = s.NewSession()
		if retryErr != nil {
			return nil, nil, fmt.Errorf("new session failed after reconnect: %w, original new session failure was: %w", retryErr, err)
		}
	}
	defer sess.Close()

	var outBuf, errBuf bytes.Buffer
	sess.Stdout = &outBuf
	sess.Stderr = &errBuf
	if stdin != nil {
		sess.Stdin = stdin
	}

	if runErr := sess.Run(command); runErr != nil {
		return outBuf.Bytes(), errBuf.Bytes(), fmt.Errorf("could not run %q though SSH: %w",
			command, runErr)
	}
	return outBuf.Bytes(), errBuf.Bytes(), nil
}
// ExecWithRetry runs the command in a loop, waiting interval between calls,
// until it succeeds or ctx is done.
//
// On success the stdout and stderr of the successful run are returned. Every
// failed attempt is logged. When ctx ends while waiting for the next attempt,
// the stdout, stderr and error of the most recent failed attempt are returned.
func (s *sshClient) ExecWithRetry(ctx context.Context, cmd string, args []string, interval time.Duration) ([]byte, []byte, error) {
	for {
		// the length of time for running the command is not blocked on the interval
		// don't create a new context with the interval as its timeout
		stdout, stderr, err := s.Exec(ctx, cmd, args, nil)
		if err == nil {
			return stdout, stderr, nil
		}
		s.logger.Logf("ssh exec error: %q, will try again in %s", err, interval)

		// wait for the interval or ctx to be cancelled; the timer is stopped on
		// cancellation so it does not linger for the rest of the interval
		timer := time.NewTimer(interval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return stdout, stderr, err
		case <-timer.C:
		}
	}
}
// Copy copies the local file at filePath to the host at dest.
//
// It starts "scp -t dest" in a new session and feeds it, over the session's
// standard input, a header line carrying the file's permission bits, size
// and dest, followed by the file contents and a terminating zero byte. The
// returned error is the first local failure, or otherwise the result of
// waiting for the remote command to finish.
//
// NOTE(review): dest is used both as the scp target and as the file name in
// the header line; confirm callers only pass values that are valid for both.
func (s *sshClient) Copy(filePath string, dest string) error {
	src, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer src.Close()

	info, err := src.Stat()
	if err != nil {
		return err
	}

	session, err := s.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	stdin, err := session.StdinPipe()
	if err != nil {
		return err
	}

	if err := session.Start(fmt.Sprintf("scp -t %s", dest)); err != nil {
		_ = stdin.Close()
		return err
	}

	// Wait for the remote command in the background; every return path below
	// receives from waitCh, so the goroutine always gets to deliver its result.
	waitCh := make(chan error)
	go func() {
		waitCh <- session.Wait()
	}()

	// abort closes the pipe so the remote side sees end of input, waits for
	// the remote command to finish and reports the local failure.
	abort := func(cause error) error {
		_ = stdin.Close()
		<-waitCh
		return cause
	}

	if _, err := fmt.Fprintf(stdin, "C%#o %d %s\n", info.Mode().Perm(), info.Size(), dest); err != nil {
		return abort(err)
	}
	if _, err := io.Copy(stdin, src); err != nil {
		return abort(err)
	}

	// Terminate the transfer and close the pipe; the outcome is whatever the
	// remote command reports.
	_, _ = fmt.Fprint(stdin, "\x00")
	_ = stdin.Close()
	return <-waitCh
}
// GetFileContents reads filename on the host and returns its content. It
// collects in memory whatever GetFileContentsOutput writes, and returns a nil
// slice together with the error when that call fails.
func (s *sshClient) GetFileContents(ctx context.Context, filename string, opts ...FileContentsOpt) ([]byte, error) {
	content := new(bytes.Buffer)
	if err := s.GetFileContentsOutput(ctx, filename, content, opts...); err != nil {
		return nil, err
	}
	return content.Bytes(), nil
}
// GetFileContentsOutput runs a command on the host that prints filename and
// streams the command's standard output into output.
//
// The command defaults to "cat" and can be changed through opts. ctx is only
// checked before the session is opened; it does not interrupt a command that
// is already running.
func (s *sshClient) GetFileContentsOutput(ctx context.Context, filename string, output io.Writer, opts ...FileContentsOpt) error {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}

	// Start from the default command and let each option adjust the settings.
	var settings fileContentsOpts
	settings.command = "cat"
	for i := range opts {
		opts[i](&settings)
	}

	sess, err := s.NewSession()
	if err != nil {
		return err
	}
	defer sess.Close()

	sess.Stdout = output
	return sess.Run(fmt.Sprintf("%s %s", settings.command, filename))
}