plugins/teststeps/exec/transport/ssh_process_async.go (255 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package transport import ( "bufio" "bytes" "errors" "fmt" "io" "os" "time" "github.com/facebookincubator/contest/pkg/xcontext" "github.com/insomniacslk/xjson" "github.com/kballard/go-shellquote" "golang.org/x/crypto/ssh" ) type sshProcessAsync struct { addr string clientConfig *ssh.ClientConfig cmd string agent string outWriter io.WriteCloser errWriter io.WriteCloser closeOnWait []io.Closer exitChan chan error stack *deferedStack } func newSSHProcessAsync( ctx xcontext.Context, addr string, clientConfig *ssh.ClientConfig, agent string, timeQuota xjson.Duration, bin string, args []string, stack *deferedStack, ) (Process, error) { // build the command to run remotely agentArgs := []string{agent} if timeQuota != 0 { agentArgs = append(agentArgs, fmt.Sprintf("--time-quota=%s", timeQuota.String())) } agentArgs = append(agentArgs, "start", bin) agentArgs = append(agentArgs, args...) cmd := shellquote.Join(agentArgs...) exitChan := make(chan error, 1) return &sshProcessAsync{ addr: addr, clientConfig: clientConfig, cmd: cmd, agent: agent, closeOnWait: []io.Closer{}, exitChan: exitChan, stack: stack, }, nil } func (spa *sshProcessAsync) Start(ctx xcontext.Context) error { errChan := make(chan error, 1) resChan := make(chan string, 1) go func() { // NOTE: golang doesnt support forking, so the started process needs to be // forcefully detached by closing the ssh session; detach is defered here client, err := ssh.Dial("tcp", spa.addr, spa.clientConfig) if err != nil { errChan <- fmt.Errorf("cannot connect to SSH server %s: %v", spa.addr, err) return } defer client.Close() session, err := client.NewSession() if err != nil { errChan <- fmt.Errorf("cannot create SSH session to server: %v", err) return } defer session.Close() stdout, err := session.StdoutPipe() if err != nil { errChan <- fmt.Errorf("failed to get stdout pipe") return } ctx.Debugf("starting remote agent: %s", spa.cmd) if err := session.Start(spa.cmd); err != nil { errChan <- fmt.Errorf("failed to start process: %w", err) return } // read the session id that the agent will put on stdout s := bufio.NewScanner(stdout) if !s.Scan() { errChan <- fmt.Errorf("agent did not return a session id") return } resChan <- s.Text() }() select { case err := <-errChan: return err case sid := <-resChan: ctx.Debugf("remote sid: %s", sid) outWriter := spa.outWriter if outWriter == nil { var err error outWriter, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { return err } } errWriter := spa.errWriter if errWriter == nil { var err error errWriter, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { return err } } mon := &asyncMonitor{spa.addr, spa.clientConfig, spa.agent, sid} go mon.Start(ctx, outWriter, errWriter, spa.exitChan) return nil case <-time.After(5 * time.Second): return fmt.Errorf("timeout while starting agent") case <-ctx.Done(): return ctx.Err() } } func (spa *sshProcessAsync) Wait(_ xcontext.Context) error { defer spa.stack.Done() // wait for process err := <-spa.exitChan var e *ssh.ExitError if errors.As(err, &e) { return fmt.Errorf("process exited with error: %w", e) } if err != nil { return fmt.Errorf("failed to wait on process: %w", err) } return nil } func (spa *sshProcessAsync) StdoutPipe() (io.Reader, error) { r, w := io.Pipe() spa.outWriter = w spa.closeOnWait = append(spa.closeOnWait, r) return r, nil } func (spa *sshProcessAsync) StderrPipe() (io.Reader, error) { r, w := io.Pipe() spa.errWriter = w spa.closeOnWait = append(spa.closeOnWait, r) return r, nil } func (spa *sshProcessAsync) String() string { return spa.cmd } // TODO: maybe extract this to a package? const ( ProcessFinishedExitCode = 13 DeadAgentExitCode = 14 ) type asyncMonitor struct { addr string clientConfig *ssh.ClientConfig agent string sid string } func (m *asyncMonitor) Start( ctx xcontext.Context, outWriter io.WriteCloser, errWriter io.WriteCloser, exitChan chan<- error, ) { defer outWriter.Close() defer errWriter.Close() for { select { case <-time.After(time.Second): ctx.Debugf("polling remote process: %s", m.sid) stdout, stderr, err, runerr := m.runAgent(ctx, "poll") if err != nil { ctx.Warnf("failed to run agent: %w", err) continue } // append stdout, stderr; blocking until read if _, err := outWriter.Write(stdout); err != nil { ctx.Warnf("failed to write to stdout pipe: %w", err) continue } if _, err := errWriter.Write(stderr); err != nil { ctx.Warnf("failed to write to stderr pipe: %w", err) continue } if runerr != nil { var em *ssh.ExitMissingError if errors.As(runerr, &em) { if err := m.reap(ctx); err != nil { ctx.Warnf("monitor error: %w", err) } // process exited without an error or signal; this is a ssh server error exitChan <- fmt.Errorf("internal ssh server error: %w", em) return } var ee *ssh.ExitError if errors.As(runerr, &ee) { if err := m.reap(ctx); err != nil { ctx.Warnf("monitor error: %w", err) } switch ee.ExitStatus() { case ProcessFinishedExitCode: // agent controlled process exited by itself exitChan <- nil case DeadAgentExitCode: // agent killed itself due to time quota or other error exitChan <- fmt.Errorf("agent exceeded time quota or just crashed") default: exitChan <- ee } return } // process is done, but there's some other internal error exitChan <- runerr } case <-ctx.Done(): ctx.Debugf("killing remote process, reason: cancellation") err := m.kill(ctx) if err := m.reap(ctx); err != nil { ctx.Warnf("monitor error: %w", err) } exitChan <- err return } } } func (m *asyncMonitor) kill(ctx xcontext.Context) error { ctx.Debugf("killing remote process: %s", m.sid) _, _, err, runerr := m.runAgent(ctx, "kill") if err != nil { return fmt.Errorf("failed to start agent kill: %w", err) } if runerr != nil { // note: this should never happen return fmt.Errorf("failed to kill remote process: %w", runerr) } return nil } func (m *asyncMonitor) reap(ctx xcontext.Context) error { ctx.Debugf("reaping remote process: %s", m.sid) _, _, err, runerr := m.runAgent(ctx, "wait") if err != nil { return fmt.Errorf("failed to start agent reap: %w", err) } if runerr != nil { return fmt.Errorf("failed to reap remote process: %w", runerr) } return nil } func (m *asyncMonitor) runAgent(ctx xcontext.Context, verb string) ([]byte, []byte, error, error) { client, err := ssh.Dial("tcp", m.addr, m.clientConfig) if err != nil { return nil, nil, fmt.Errorf("cannot connect to SSH server %s: %v", m.addr, err), nil } defer client.Close() session, err := client.NewSession() if err != nil { return nil, nil, fmt.Errorf("cannot create SSH session to server: %w", err), nil } defer session.Close() var stdout, stderr bytes.Buffer session.Stdout = &stdout session.Stderr = &stderr cmd := shellquote.Join(m.agent, verb, m.sid) ctx.Debugf("starting agent command: %s", cmd) if err := session.Start(cmd); err != nil { return nil, nil, fmt.Errorf("failed to start remote agent: %w", err), nil } // note: dont move this to the return line because stdout will be empty runerr := session.Wait() return stdout.Bytes(), stderr.Bytes(), nil, runerr }