pkg/helpers/ssh/ssh.go (184 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
package ssh
import (
"context"
"encoding/base64"
"fmt"
"net"
"os"
"path"
"strings"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
kh "golang.org/x/crypto/ssh/knownhosts"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
)
// ExecuteRemote runs script on a remote host over SSH and returns everything
// the script wrote to stdout and stderr.
//
// The combined output is returned both when the script succeeds and when it
// fails, so callers can surface it next to the returned error. It is empty
// only when the SSH client or session could not be created.
//
// Context ctx is only enforced during the process that establishes
// the SSH connection and creates the SSH client; it does not bound the
// execution of the script itself.
func ExecuteRemote(ctx context.Context, host *RemoteHost, script string) (combinedOutput string, err error) {
	c, err := clientWithRetry(ctx, host)
	if err != nil {
		return "", errors.Wrap(err, "creating SSH client")
	}
	defer c.Close()

	s, err := c.NewSession()
	if err != nil {
		return "", errors.Wrap(err, "creating SSH session")
	}
	defer s.Close()

	co, err := s.CombinedOutput(script)
	if err != nil {
		return string(co), errors.Wrap(err, "executing script")
	}
	// Hand the output back on success too; dropping it here would make the
	// combinedOutput result useless for every script that works.
	return string(co), nil
}
// PublicKeyAuth builds an ssh.AuthMethod backed by the private key stored at
// sshPrivateKeyPath. It returns an error if the file cannot be read or if its
// content cannot be parsed as a private key.
func PublicKeyAuth(sshPrivateKeyPath string) (ssh.AuthMethod, error) {
	keyBytes, readErr := os.ReadFile(sshPrivateKeyPath)
	if readErr != nil {
		return nil, errors.Wrap(readErr, "reading ssh private key file")
	}

	signer, parseErr := ssh.ParsePrivateKey(keyBytes)
	if parseErr != nil {
		return nil, errors.Wrap(parseErr, "parsing ssh private key content")
	}

	return ssh.PublicKeys(signer), nil
}
// ValidateConfig checks the JumpBox configuration by building its SSH client
// configuration and opening an SSH connection to host.URI:host.Port.
//
// It returns nil when the connection can be established, and a wrapped error
// when either the client configuration cannot be created or the dial fails.
// The probe connection is closed before returning.
func ValidateConfig(host *JumpBox) error {
	jbConfig, err := config(host.AuthConfig)
	if err != nil {
		return errors.Wrap(err, "creating ssh client config")
	}
	jbClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host.URI, host.Port), jbConfig)
	if err != nil {
		return errors.Wrapf(err, "dialing ssh (%s)", host.URI)
	}
	// The connection only served as a probe, release it right away. A close
	// error does not invalidate the configuration, so it is not reported.
	_ = jbClient.Close()
	return nil
}
// clientWithRetry calls client(host) through retry.OnError, using a backoff
// of 300 steps with a 10 second duration. After a failed attempt, a new one
// is only allowed while ctx is not done.
//
// It returns the client produced by the last attempt together with the error
// reported by retry.OnError.
func clientWithRetry(ctx context.Context, host *RemoteHost) (*ssh.Client, error) {
	var sshClient *ssh.Client

	attempt := func() error {
		var attemptErr error
		sshClient, attemptErr = client(host)
		return attemptErr
	}

	// TODO Granular retry func
	// Every error is considered retriable as long as ctx has not been
	// cancelled or timed out.
	retriable := func(error) bool {
		return ctx.Err() == nil
	}

	retryErr := retry.OnError(wait.Backoff{Steps: 300, Duration: 10 * time.Second}, retriable, attempt)
	return sshClient, retryErr
}
// client opens an SSH client to host by tunnelling through its jumpbox:
// it dials the jumpbox over TCP, asks the jumpbox to dial host.URI:host.Port,
// and runs the SSH client handshake for the host over that tunnelled
// connection.
//
// On any failure after the jumpbox has been dialed, the jumpbox connection is
// closed before the error is returned, so repeated attempts do not pile up
// open connections. On success, the jumpbox connection is closed
// automatically once the returned client's connection has shut down.
func client(host *RemoteHost) (*ssh.Client, error) {
	jbConfig, err := config(host.Jumpbox.AuthConfig)
	if err != nil {
		return nil, errors.Wrap(err, "creating jumpbox client config")
	}
	jbConn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host.Jumpbox.URI, host.Jumpbox.Port), jbConfig)
	if err != nil {
		return nil, errors.Wrapf(err, "dialing jumpbox (%s)", host.Jumpbox.URI)
	}

	// From here on, every early return must release the jumpbox connection.
	// Closing it also tears down anything tunnelled through it.
	established := false
	defer func() {
		if !established {
			_ = jbConn.Close()
		}
	}()

	hostConn, err := jbConn.Dial("tcp", fmt.Sprintf("%s:%d", host.URI, host.Port))
	if err != nil {
		return nil, errors.Wrapf(err, "dialing host (%s)", host.URI)
	}
	hostConfig, err := config(host.AuthConfig)
	if err != nil {
		_ = hostConn.Close()
		return nil, errors.Wrap(err, "creating host client config")
	}
	ncc, chans, reqs, err := ssh.NewClientConn(hostConn, host.URI, hostConfig)
	if err != nil {
		_ = hostConn.Close()
		return nil, errors.Wrapf(err, "starting new client connection to host (%s)", host.URI)
	}

	// ssh.NewClient cannot fail, so there is no error to check here.
	c := ssh.NewClient(ncc, chans, reqs)
	established = true

	// Callers only ever see (and close) the host client. Tie the lifetime of
	// the jumpbox connection to it so it is released as well.
	go func() {
		_ = c.Wait()
		_ = jbConn.Close()
	}()

	return c, nil
}
// config assembles the ssh.ClientConfig for authConfig: the user and
// authentication method come from authConfig, and host keys are verified
// against the known_hosts file at khpath.
//
// The host key callback behaves as follows:
//   - a key that matches a known_hosts entry is accepted;
//   - a key marked as revoked in known_hosts is rejected;
//   - a key that differs from the one recorded for the host is rejected,
//     after logging the mismatch;
//   - a key for a host with no recorded key is accepted and appended to the
//     known_hosts file (trust on first use).
//
// It returns an error if the authentication method or the known_hosts
// callback cannot be created.
func config(authConfig *AuthConfig) (*ssh.ClientConfig, error) {
	authMethod, err := clientConfigAuth(authConfig)
	if err != nil {
		return nil, err
	}
	hkCallback, err := knownHostsHostKeyCallback()
	if err != nil {
		return nil, err
	}
	hostKeyCallback := func(host string, remote net.Addr, pubKey ssh.PublicKey) error {
		cbErr := hkCallback(host, remote, pubKey)
		if cbErr == nil {
			return nil
		}

		// A revoked key must never be accepted, regardless of the host.
		var revokedErr *kh.RevokedError
		if errors.As(cbErr, &revokedErr) {
			return cbErr
		}

		var keyErr *kh.KeyError
		if !errors.As(cbErr, &keyErr) {
			// Any other callback error is not treated as a verification
			// failure and the key is accepted.
			// NOTE(review): this means host key verification is skipped
			// whenever the known_hosts checker fails for another reason
			// (e.g. an address it cannot parse) — confirm this is intended
			// before tightening it.
			return nil
		}

		// Strip the port. net.SplitHostPort understands bracketed IPv6
		// literals; fall back to a plain split if the address has no port.
		hostname, _, splitErr := net.SplitHostPort(host)
		if splitErr != nil {
			hostname = strings.Split(host, ":")[0]
		}

		// A non-empty Want means known_hosts has keys for this host and
		// none of them matches the one presented.
		if len(keyErr.Want) > 0 {
			log.Errorf("Strict host key check failed. Remote host identification has changed.")
			log.Errorf("Key '%v' does not match a key known for host %s.", hostKeyString(pubKey), hostname)
			return keyErr
		}

		// Unknown host: record its key so later connections are verified.
		if err := addHostKey(hostname, pubKey); err != nil {
			return err
		}
		log.Warnf("Permanently added '%s' (%s) to the list of known hosts (%s)", hostname, pubKey.Type(), khpath)
		return nil
	}
	return &ssh.ClientConfig{
		HostKeyCallback: hostKeyCallback,
		User:            authConfig.User,
		Auth:            authMethod,
	}, nil
}
// clientConfigAuth picks the ssh authentication method for authConfig.
//
// Password authentication is used when no private key path is set; otherwise
// public key authentication is built from the key at PrivateKeyPath. The
// result always holds exactly one method.
func clientConfigAuth(authConfig *AuthConfig) ([]ssh.AuthMethod, error) {
	// TODO we may be able to return both methods, not changing behavior for now
	// TODO this can be reworked so it is executed once per command
	if authConfig.PrivateKeyPath == "" {
		passwordAuth := ssh.Password(authConfig.Password)
		return []ssh.AuthMethod{passwordAuth}, nil
	}

	keyAuth, keyErr := PublicKeyAuth(authConfig.PrivateKeyPath)
	if keyErr != nil {
		return nil, errors.Wrap(keyErr, "creating public key authentication method")
	}
	return []ssh.AuthMethod{keyAuth}, nil
}
// knownHostsHostKeyCallback makes sure the known_hosts file at khpath exists
// and returns a host key callback that checks keys against that file.
func knownHostsHostKeyCallback() (ssh.HostKeyCallback, error) {
	if ensureErr := ensuresKnownHosts(); ensureErr != nil {
		return nil, ensureErr
	}

	callback, newErr := kh.New(khpath)
	if newErr != nil {
		return nil, errors.Wrap(newErr, "creating HostKeyCallback instance")
	}
	return callback, nil
}
// ensuresKnownHosts guarantees that the known_hosts file at khpath exists,
// creating its parent directory (mode 0700) and the file itself (mode 0600)
// when they are missing. An existing file is left untouched.
func ensuresKnownHosts() error {
	sshDir := path.Dir(khpath)
	if mkdirErr := os.MkdirAll(sshDir, 0700); mkdirErr != nil {
		return errors.Wrap(mkdirErr, "creating .ssh directory")
	}

	// O_CREATE alone creates the file if needed without truncating it.
	khFile, openErr := os.OpenFile(khpath, os.O_CREATE, 0600)
	if openErr != nil {
		return errors.Wrap(openErr, "creating known_hosts file")
	}
	khFile.Close()

	return nil
}
// hostKeyString renders a public key as "<type> <base64 of the marshalled key>".
func hostKeyString(k ssh.PublicKey) string {
	keyType := k.Type()
	encodedKey := base64.StdEncoding.EncodeToString(k.Marshal())
	return fmt.Sprintf("%s %s", keyType, encodedKey)
}
// addHostKey appends a known_hosts entry for hostname and pubKey to the file
// at khpath. The file must already exist; it is opened in append mode.
//
// A newline is written before the entry, and the entry itself is written
// without a trailing newline.
func addHostKey(hostname string, pubKey ssh.PublicKey) error {
	khFile, openErr := os.OpenFile(khpath, os.O_APPEND|os.O_WRONLY, 0600)
	if openErr != nil {
		return errors.Wrap(openErr, "opening known_hosts file")
	}
	defer khFile.Close()

	// Start on a fresh line so the new entry never joins the previous one.
	if _, writeErr := fmt.Fprintln(khFile); writeErr != nil {
		return errors.Wrap(writeErr, "appending blank line to known_hosts file")
	}

	// Build and append the host key line.
	entry := kh.Line([]string{kh.Normalize(hostname)}, pubKey)
	if _, writeErr := khFile.WriteString(entry); writeErr != nil {
		return errors.Wrap(writeErr, "writing known_hosts file")
	}
	return nil
}