plugins/teststeps/sshcmd/sshcmd.go (256 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package sshcmd // The SSHCmd plugin implements an SSH command executor step. Only PublicKey // and Password authentication are supported. GSSAPI not supported yet. // // Warning: this plugin does not lock password and keys in memory, and does no // safe erase in memory to avoid forensic attacks. If you need that, please // submit a PR. // // Warning: commands are interpreted, so be careful with external input in the // test step arguments. import ( "bytes" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net" "regexp" "strconv" "time" "github.com/kballard/go-shellquote" "golang.org/x/crypto/ssh" "github.com/facebookincubator/contest/pkg/event" "github.com/facebookincubator/contest/pkg/event/testevent" "github.com/facebookincubator/contest/pkg/target" "github.com/facebookincubator/contest/pkg/test" "github.com/facebookincubator/contest/pkg/xcontext" "github.com/facebookincubator/contest/plugins/teststeps" ) // Name is the name used to look this plugin up. var Name = "SSHCmd" // Events is used by the framework to determine which events this plugin will // emit. Any emitted event that is not registered here will cause the plugin to // fail. var Events = []event.Name{} const defaultSSHPort = 22 const defaultTimeoutParameter = "10m" // SSHCmd is used to run arbitrary commands as test steps. type SSHCmd struct { Host *test.Param Port *test.Param User *test.Param PrivateKeyFile *test.Param Password *test.Param Executable *test.Param Args []test.Param Expect *test.Param Timeout *test.Param SkipIfEmptyHost *test.Param } // Name returns the plugin name. func (ts SSHCmd) Name() string { return Name } // Run executes the cmd step. func (ts *SSHCmd) Run(ctx xcontext.Context, ch test.TestStepChannels, params test.TestStepParameters, ev testevent.Emitter, resumeState json.RawMessage) (json.RawMessage, error) { log := ctx.Logger() // XXX: Dragons ahead! The target (%t) substitution, and function // expression evaluations are done at run-time, so they may still fail // despite passing at early validation time. // If the function evaluations called in validateAndPopulate are not idempotent, // the output of the function expressions may be different (e.g. with a call to a // backend or a random pool of results) // Function evaluation could be done at validation time, but target // substitution cannot, because the targets are not known at that time. if err := ts.validateAndPopulate(params); err != nil { return nil, err } f := func(ctx xcontext.Context, target *target.Target) error { // apply filters and substitutions to user, host, private key, and command args user, err := ts.User.Expand(target) if err != nil { return fmt.Errorf("cannot expand user parameter: %v", err) } host, err := ts.Host.Expand(target) if err != nil { return fmt.Errorf("cannot expand host parameter: %v", err) } if len(host) == 0 { shouldSkip := false if !ts.SkipIfEmptyHost.IsEmpty() { var err error shouldSkip, err = strconv.ParseBool(ts.SkipIfEmptyHost.String()) if err != nil { return fmt.Errorf("cannot expand 'skip_if_empty_host' parameter value '%s': %w", ts.SkipIfEmptyHost, err) } } if shouldSkip { return nil } else { return fmt.Errorf("host value is empty") } } portStr, err := ts.Port.Expand(target) if err != nil { return fmt.Errorf("cannot expand port parameter: %v", err) } port, err := strconv.Atoi(portStr) if err != nil { return fmt.Errorf("failed to convert port parameter to integer: %v", err) } timeoutStr, err := ts.Timeout.Expand(target) if err != nil { return fmt.Errorf("cannot expand timeout parameter %s: %v", timeoutStr, err) } timeout, err := time.ParseDuration(timeoutStr) if err != nil { return fmt.Errorf("cannot parse timeout paramter: %v", err) } timeTimeout := time.Now().Add(timeout) // apply functions to the private key, if any var signer ssh.Signer privKeyFile, err := ts.PrivateKeyFile.Expand(target) if err != nil { return fmt.Errorf("cannot expand private key file parameter: %v", err) } if privKeyFile != "" { key, err := ioutil.ReadFile(privKeyFile) if err != nil { return fmt.Errorf("cannot read private key at %s: %v", ts.PrivateKeyFile, err) } signer, err = ssh.ParsePrivateKey(key) if err != nil { return fmt.Errorf("cannot parse private key: %v", err) } } password, err := ts.Password.Expand(target) if err != nil { return fmt.Errorf("cannot expand password parameter: %v", err) } auth := []ssh.AuthMethod{} if signer != nil { auth = append(auth, ssh.PublicKeys(signer)) } if password != "" { auth = append(auth, ssh.Password(password)) } config := ssh.ClientConfig{ User: user, Auth: auth, // TODO expose this in the plugin arguments //HostKeyCallback: ssh.FixedHostKey(hostKey), HostKeyCallback: ssh.InsecureIgnoreHostKey(), } executable, err := ts.Executable.Expand(target) if err != nil { return fmt.Errorf("cannot expand executable parameter: %v", err) } // apply functions to the command args, if any var args []string for _, arg := range ts.Args { earg, err := arg.Expand(target) if err != nil { return fmt.Errorf("cannot expand command argument '%s': %v", arg, err) } args = append(args, earg) } // connect to the host addr := net.JoinHostPort(host, strconv.Itoa(port)) client, err := ssh.Dial("tcp", addr, &config) if err != nil { return fmt.Errorf("cannot connect to SSH server %s: %v", addr, err) } defer func() { if err := client.Close(); err != nil { ctx.Warnf("Failed to close SSH connection to %s: %v", addr, err) } }() session, err := client.NewSession() if err != nil { return fmt.Errorf("cannot create SSH session to server %s: %v", addr, err) } defer func() { if err := session.Close(); err != nil && err != io.EOF { ctx.Warnf("Failed to close SSH session to %s: %v", addr, err) } }() // run the remote command and catch stdout/stderr var stdout, stderr bytes.Buffer session.Stdout, session.Stderr = &stdout, &stderr cmd := shellquote.Join(append([]string{executable}, args...)...) log.Debugf("Running remote SSH command on %s: '%v'", addr, cmd) errCh := make(chan error, 1) go func() { innerErr := session.Run(cmd) errCh <- innerErr }() expect := ts.Expect.String() re, err := regexp.Compile(expect) keepAliveCnt := 0 if err != nil { return fmt.Errorf("malformed expect parameter: Can not compile %s with %v", expect, err) } for { select { case err := <-errCh: log.Infof("Stdout of command '%s' is '%s'", cmd, stdout.Bytes()) if err == nil { // Execute expectations if expect == "" { ctx.Warnf("no expectations specified") } else { matches := re.FindAll(stdout.Bytes(), -1) if len(matches) > 0 { log.Infof("match for regex '%s' found", expect) } else { return fmt.Errorf("match for %s not found for target %v", expect, target) } } } else { ctx.Warnf("Stderr of command '%s' is '%s'", cmd, stderr.Bytes()) } return err case <-ctx.Done(): return session.Signal(ssh.SIGKILL) case <-time.After(250 * time.Millisecond): keepAliveCnt++ if expect != "" { matches := re.FindAll(stdout.Bytes(), -1) if len(matches) > 0 { log.Infof("match for regex '%s' found", expect) return nil } } if time.Now().After(timeTimeout) { return fmt.Errorf("timed out after %s", timeout) } // This is needed to keep the connection to the server alive if keepAliveCnt%20 == 0 { err = session.Signal(ssh.Signal("CONT")) if err != nil { log.Warnf("Unable to send CONT to ssh server: %v", err) } } } } } return teststeps.ForEachTarget(Name, ctx, ch, f) } func (ts *SSHCmd) validateAndPopulate(params test.TestStepParameters) error { var err error ts.Host = params.GetOne("host") if ts.Host.IsEmpty() { return errors.New("invalid or missing 'host' parameter, must be exactly one string") } if params.GetOne("port").IsEmpty() { ts.Port = test.NewParam(strconv.Itoa(defaultSSHPort)) } else { var port int64 port, err = params.GetInt("port") if err != nil { return fmt.Errorf("invalid 'port' parameter, not an integer: %v", err) } if port < 0 || port > 0xffff { return fmt.Errorf("invalid 'port' parameter: not in range 0-65535") } } ts.User = params.GetOne("user") if ts.User.IsEmpty() { return errors.New("invalid or missing 'user' parameter, must be exactly one string") } ts.PrivateKeyFile = params.GetOne("private_key_file") // do not fail if key file is empty, in such case it won't be used ts.PrivateKeyFile = params.GetOne("private_key_file") // do not fail if password is empty, in such case it won't be used ts.Password = params.GetOne("password") ts.Executable = params.GetOne("executable") if ts.Executable.IsEmpty() { return errors.New("invalid or missing 'executable' parameter, must be exactly one string") } ts.Args = params.Get("args") ts.Expect = params.GetOne("expect") if params.GetOne("timeout").IsEmpty() { ts.Timeout = test.NewParam(defaultTimeoutParameter) } else { ts.Timeout = params.GetOne("timeout") } ts.SkipIfEmptyHost = params.GetOne("skip_if_empty_host") return nil } // ValidateParameters validates the parameters associated to the TestStep func (ts *SSHCmd) ValidateParameters(ctx xcontext.Context, params test.TestStepParameters) error { ctx.Debugf("Params %+v", params) return ts.validateAndPopulate(params) } // New initializes and returns a new SSHCmd test step. func New() test.TestStep { return &SSHCmd{} } // Load returns the name, factory and events which are needed to register the step. func Load() (string, test.TestStepFactory, []event.Name) { return Name, New, Events }