plugins/teststeps/exec/transport/ssh_transport.go (148 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package transport
import (
"fmt"
"io/ioutil"
"net"
"os"
"strconv"
"time"
"github.com/google/uuid"
"github.com/insomniacslk/xjson"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"github.com/facebookincubator/contest/pkg/xcontext"
)
// SSHTransportConfig holds the settings SSHTransport uses to connect to a
// remote host and run a process there. Fields carry JSON tags so the config
// can be (de)serialized as JSON.
type SSHTransportConfig struct {
	// Host is the SSH server host name or address; it is combined with Port
	// via net.JoinHostPort when dialing.
	Host string `json:"host,omitempty"`
	// Port is the SSH server TCP port (22 in DefaultSSHTransportConfig).
	Port int `json:"port,omitempty"`
	// User is the remote login user name.
	User string `json:"user,omitempty"`
	// Password, when non-empty, enables password authentication.
	Password string `json:"password,omitempty"`
	// IdentityFile, when non-empty, is the local path of a private key used
	// for public key authentication (listed before the password method).
	IdentityFile string `json:"identity_file,omitempty"`
	// Timeout is passed as ssh.ClientConfig.Timeout when dialing (10 minutes
	// in DefaultSSHTransportConfig).
	Timeout xjson.Duration `json:"timeout,omitempty"`
	// SendBinary, when true, treats the binary as a local file that is
	// copied to a temporary path on the remote host before running and is
	// removed again afterwards.
	SendBinary bool `json:"send_binary,omitempty"`
	// Async, when non-nil, runs the process through an async agent instead
	// of a plain SSH process.
	Async *struct {
		// Agent is the local path of the agent binary; it is copied to the
		// remote host before use.
		Agent string `json:"agent,omitempty"`
		// TimeQuota is forwarded to the async process; presumably it bounds
		// how long the agent lets the process run — TODO confirm.
		TimeQuota xjson.Duration `json:"time_quota,omitempty"`
	} `json:"async,omitempty"`
}
// DefaultSSHTransportConfig returns an SSHTransportConfig pre-filled with the
// standard SSH port (22) and a 10 minute connection timeout. Every other field
// is left at its zero value.
func DefaultSSHTransportConfig() SSHTransportConfig {
	var cfg SSHTransportConfig
	cfg.Port = 22
	cfg.Timeout = xjson.Duration(10 * time.Minute)
	return cfg
}
// SSHTransport is a Transport that runs processes on a remote host over SSH.
// Its behavior is driven entirely by the embedded SSHTransportConfig.
type SSHTransport struct {
	SSHTransportConfig
}
// NewSSHTransport builds a Transport that executes processes over SSH using a
// copy of the given configuration.
func NewSSHTransport(config SSHTransportConfig) Transport {
	t := &SSHTransport{SSHTransportConfig: config}
	return t
}
// NewProcess connects to the configured SSH server and prepares bin (with
// args) to run there, returning the resulting Process.
//
// Authentication uses the private key at IdentityFile and/or Password,
// whichever are set (key first). Host keys are NOT verified
// (ssh.InsecureIgnoreHostKey). When SendBinary is set, bin is a local path
// that is validated, copied to the remote host and removed once the process
// ends. When Async is set the process is created via newAsync.
//
// Cleanups (closing the SSH client, removing copied files) are registered on a
// deferedStack that is handed to the process so they run after it ends. If
// preparation fails before that hand-off, the SSH client is closed here so it
// does not leak.
func (st *SSHTransport) NewProcess(ctx xcontext.Context, bin string, args []string) (Process, error) {
	var auth []ssh.AuthMethod
	if st.IdentityFile != "" {
		key, err := ioutil.ReadFile(st.IdentityFile)
		if err != nil {
			return nil, fmt.Errorf("cannot read private key at %s: %w", st.IdentityFile, err)
		}
		signer, err := ssh.ParsePrivateKey(key)
		if err != nil {
			return nil, fmt.Errorf("cannot parse private key: %w", err)
		}
		auth = append(auth, ssh.PublicKeys(signer))
	}
	if st.Password != "" {
		auth = append(auth, ssh.Password(st.Password))
	}

	// Validate the local binary before opening a network connection: a bad
	// path fails fast and no SSH client is left behind.
	if st.SendBinary {
		if err := checkBinary(bin); err != nil {
			return nil, err
		}
	}

	addr := net.JoinHostPort(st.Host, strconv.Itoa(st.Port))
	clientConfig := &ssh.ClientConfig{
		User: st.User,
		Auth: auth,
		// TODO expose this in the plugin arguments
		//HostKeyCallback: ssh.FixedHostKey(hostKey),
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         time.Duration(st.Timeout),
	}
	client, err := ssh.Dial("tcp", addr, clientConfig)
	if err != nil {
		return nil, fmt.Errorf("cannot connect to SSH server %s: %w", addr, err)
	}

	if st.SendBinary {
		remoteBin, err := st.sendFile(ctx, client, bin, 0500)
		if err != nil {
			// No cleanup stack owns the client yet, so close it here.
			if cerr := client.Close(); cerr != nil {
				ctx.Warnf("failed to close SSH client: %v", cerr)
			}
			return nil, fmt.Errorf("cannot send binary to remote ssh: %w", err)
		}
		bin = remoteBin
	}

	// stack mechanism similar to defer, but run after the exec process ends.
	// It is created only after every local failure point above, so it is
	// never orphaned by an early return from this function.
	// NOTE(review): if st.new/st.newAsync fail, cleanup depends on how they
	// (and deferedStack) handle the stack — confirm.
	stack := newDeferedStack()
	// cleanup the ssh client after the operations have ended
	stack.Add(func() {
		if err := client.Close(); err != nil {
			ctx.Warnf("failed to close SSH client: %v", err)
		}
	})
	if st.SendBinary {
		// cleanup the sent file so we don't leave hanging files around
		stack.Add(func() {
			ctx.Debugf("cleaning remote file: %s", bin)
			if err := st.unlinkFile(ctx, client, bin); err != nil {
				ctx.Warnf("failed to cleanup remote file: %v", err)
			}
		})
	}

	if st.Async != nil {
		return st.newAsync(ctx, client, addr, clientConfig, bin, args, stack)
	}
	return st.new(ctx, client, bin, args, stack)
}
// new creates a plain (non-async) SSH process for bin/args on the already
// connected client, handing it the cleanup stack built by NewProcess.
func (st *SSHTransport) new(ctx xcontext.Context, client *ssh.Client, bin string, args []string, stack *deferedStack) (Process, error) {
	proc, err := newSSHProcess(ctx, client, bin, args, stack)
	return proc, err
}
// newAsync prepares bin to run through the async agent configured in
// st.Async.
//
// The agent binary (a local path in st.Async.Agent) is copied to the remote
// host over client with mode 0500, and its removal is registered on stack.
// The process itself is created by newSSHProcessAsync, which receives addr
// and clientConfig (presumably so it can open its own connections — confirm)
// together with st.Async.TimeQuota.
func (st *SSHTransport) newAsync(
	ctx xcontext.Context,
	client *ssh.Client, addr string, clientConfig *ssh.ClientConfig,
	bin string, args []string,
	stack *deferedStack,
) (Process, error) {
	// Fail clearly instead of letting sendFile trip over an empty local path.
	if st.Async.Agent == "" {
		return nil, fmt.Errorf("async agent path is not configured")
	}
	// we always need the agent for the async case
	agent, err := st.sendFile(ctx, client, st.Async.Agent, 0500)
	if err != nil {
		return nil, fmt.Errorf("failed to send agent: %w", err)
	}
	stack.Add(func() {
		ctx.Debugf("cleaning async agent: %s", agent)
		if err := st.unlinkFile(ctx, client, agent); err != nil {
			ctx.Warnf("failed to cleanup async agent: %v", err)
		}
	})
	return newSSHProcessAsync(ctx, addr, clientConfig, agent, st.Async.TimeQuota, bin, args, stack)
}
// sendFile uploads the local file bin over SFTP to a uniquely named path under
// /tmp on the remote host, sets its permissions to mode and returns that
// remote path.
//
// The local file is opened before anything is created remotely. If the copy,
// chmod or close fails after the remote file was created, the remote file is
// removed (best effort, a failure is only logged) and "" is returned along
// with the error.
func (st *SSHTransport) sendFile(ctx xcontext.Context, client *ssh.Client, bin string, mode os.FileMode) (string, error) {
	fin, err := os.Open(bin)
	if err != nil {
		return "", fmt.Errorf("cannot open source bin file: %w", err)
	}
	defer fin.Close()

	sftpClient, err := sftp.NewClient(client)
	if err != nil {
		return "", fmt.Errorf("failed to create sftp client: %w", err)
	}
	defer sftpClient.Close()

	remoteBin := fmt.Sprintf("/tmp/exec_bin_%s", uuid.New().String())
	fout, err := sftpClient.Create(remoteBin)
	if err != nil {
		return "", fmt.Errorf("failed to create sftp file: %w", err)
	}

	ctx.Debugf("sending file to remote: %s", remoteBin)
	var sendErr error
	if _, err := fout.ReadFrom(fin); err != nil {
		sendErr = fmt.Errorf("failed to send file: %w", err)
	} else if err := fout.Chmod(mode); err != nil {
		sendErr = fmt.Errorf("failed to chmod remote file: %w", err)
	}
	// A failed close can mean the upload was not fully committed remotely.
	if err := fout.Close(); err != nil && sendErr == nil {
		sendErr = fmt.Errorf("failed to close remote file: %w", err)
	}

	if sendErr != nil {
		// Don't leave a partial/unusable file behind on the remote host.
		if err := sftpClient.Remove(remoteBin); err != nil {
			ctx.Warnf("failed to remove partial remote file %s: %v", remoteBin, err)
		}
		return "", sendErr
	}
	return remoteBin, nil
}
// unlinkFile deletes the file at path bin on the remote host through a
// short-lived SFTP session opened on client, returning the removal error.
func (st *SSHTransport) unlinkFile(ctx xcontext.Context, client *ssh.Client, bin string) error {
	session, err := sftp.NewClient(client)
	if err != nil {
		return fmt.Errorf("failed to create sftp client: %w", err)
	}
	defer session.Close()

	removeErr := session.Remove(bin)
	return removeErr
}