helpers/ssh/stub_ssh_server.go (361 lines of code) (raw):

package ssh import ( "context" "encoding/binary" "errors" "fmt" "io" "net" "os" "os/exec" "path/filepath" "strconv" "sync" "sync/atomic" "syscall" "time" "gitlab.com/gitlab-org/gitlab-runner/common" "golang.org/x/crypto/ssh" ) type StubSSHServer struct { Options User string Password string Config *ssh.ServerConfig Shell []string host string port string privateKeyLocation string stopped chan struct{} tempDir string listener net.Listener once sync.Once err error closed atomic.Bool } var TestSSHKeyPair = struct { PublicKey string PrivateKey string }{ PrivateKey: `-----BEGIN OPENSSH PRIVATE KEY----- b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn NhAAAAAwEAAQAAAIEA2FnuhEf3bCtSe6eyg5/Ir3kzjGx3gFij1H3QmerGIzz7JW+oxVWf r+x7Ix61dZcE/8VXow4C2BFOXRNoa8KFN1gQh+jbbZTgc1sWCTyr6iKZIDoKR59W4pceTP TnAQ4RHNNJwhCTDDsYlklCRBpJ79d6nt9r5O2kbVju3/wTCUsAAAIYw8mlC8PJpQsAAAAH c3NoLXJzYQAAAIEA2FnuhEf3bCtSe6eyg5/Ir3kzjGx3gFij1H3QmerGIzz7JW+oxVWfr+ x7Ix61dZcE/8VXow4C2BFOXRNoa8KFN1gQh+jbbZTgc1sWCTyr6iKZIDoKR59W4pceTPTn AQ4RHNNJwhCTDDsYlklCRBpJ79d6nt9r5O2kbVju3/wTCUsAAAADAQABAAAAgGBufUSSuz KIgMRC8+t9Hbswv4w4kG8xkxxUU9U28sekF6ERCt2iE4IbWqtFtcXK4VyLfktcJGJgHFia HPHjCvLVKGxBqoM1beWctSIpdjlu+VJedNkaFpEKZRe7Wpx61B7an+JdZJiR87CSJxkkGE GLhuZwio6O8bBof2NEtScxAAAAQCzvxCvu+cswV+V4TYeTc/Wr7WN0J4omkwKWa0y69Z2Y 8zV2SpSoex+7mCsWQrumDCxIn+lQ7g45kdoYqAIPWZwAAABBAPRzwg8P861S4jMxnTFMUb 0izGpRrSSyrMWmhnB6do42CavG1LrS6bo0JTHVRb2uhP0OVfSWscb8C2s2oXK7FTMAAABB AOKSVxw+gKB6O9Ez6Tr732hotJVeo04HGZ3ZCQWigFabouRbR5dUntt5ElRmCFVSJW/XnZ tlxpSUh4YUnfTGi4kAAAAham9obmNhaUBKb2hucy1NYWNCb29rLVByby0zLmxvY2FsAQI= -----END OPENSSH PRIVATE KEY-----`, PublicKey: `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDYWe6ER/dsK1J7p7KDn8iveTOMbHeAWKPUfdCZ6sYjPPslb6jFVZ+v7HsjHrV1lwT/xVejDgLYEU5dE2hrwoU3WBCH6NttlOBzWxYJPKvqIpkgOgpHn1bilx5M9OcBDhEc00nCEJMMOxiWSUJEGknv13qe32vk7aRtWO7f/BMJSw==`, } type Option func(*Options) type Options struct { DontAcceptConnections bool ExecuteLocal bool } func WithDontAcceptConnections() Option { return func(o *Options) { o.DontAcceptConnections = true } } func WithExecuteLocal() Option { return func(o *Options) { o.ExecuteLocal = true } } func NewStubServer(user, pass string, opts ...Option) (server *StubSSHServer, err error) { tempDir, err := os.MkdirTemp("", "ssh-stub-server") if err != nil { return nil, err } var options Options for _, opt := range opts { opt(&options) } server = &StubSSHServer{ Options: options, User: user, Password: pass, Config: &ssh.ServerConfig{ PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if conn.User() == user && string(password) == pass { return nil, nil } return nil, fmt.Errorf("wrong password for %q", conn.User()) }, }, stopped: make(chan struct{}), tempDir: tempDir, } privateKeyLocation := filepath.Join(tempDir, "id_rsa_test") publicKeyLocation := filepath.Join(tempDir, "id_rsa_test.pub") defer func() { if err != nil { os.RemoveAll(tempDir) } }() if err := os.WriteFile(privateKeyLocation, []byte(TestSSHKeyPair.PrivateKey), 0o600); err != nil { return nil, err } if err := os.WriteFile(publicKeyLocation, []byte(TestSSHKeyPair.PublicKey), 0o600); err != nil { return nil, err } key, err := ssh.ParsePrivateKey([]byte(TestSSHKeyPair.PrivateKey)) if err != nil { return nil, err } server.privateKeyLocation = privateKeyLocation server.Config.AddHostKey(key) if err := server.start(); err != nil { return nil, err } return server, nil } func (s *StubSSHServer) start() error { listener, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return err } s.listener = listener host, port, err := net.SplitHostPort(listener.Addr().String()) if err != nil { return err } s.host = host s.port = port go s.mainLoop(listener) return err } func (s *StubSSHServer) setError(err error) { if errors.Is(err, io.EOF) { return } if err != nil { s.once.Do(func() { s.err = err }) } } func (s *StubSSHServer) Host() string { return s.host } func (s *StubSSHServer) Port() string { return s.port } func (s *StubSSHServer) Stop() error { if s.closed.Load() { return s.err } s.closed.Store(true) s.listener.Close() os.RemoveAll(s.tempDir) err := s.err // if the error is expected because we cancelled, don't return an error if errors.Is(err, context.Canceled) { err = nil } select { case <-s.stopped: return err case <-time.After(45 * time.Second): return fmt.Errorf("timed out waiting for active ssh session to close") } } //nolint:gocognit func (s *StubSSHServer) mainLoop(listener net.Listener) { defer close(s.stopped) var wg sync.WaitGroup defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() for { if s.closed.Load() { return } if s.DontAcceptConnections { time.Sleep(time.Second) continue } conn, err := listener.Accept() if errors.Is(err, net.ErrClosed) { return } if errors.Is(err, io.EOF) { continue } if err != nil { s.setError(err) return } _, channels, reqs, err := ssh.NewServerConn(conn, s.Config) if !s.ExecuteLocal { // existing tests rely on us just continuing without serving the SSH request if we're not executing locally continue } if err != nil { s.setError(err) return } go ssh.DiscardRequests(reqs) go func() { for channel := range channels { wg.Add(1) go func(channel ssh.NewChannel) { defer wg.Done() var err error switch channel.ChannelType() { case "session": err = s.handleSession(ctx, channel) case "direct-tcpip": var directTCPIP struct { DestAddr string DestPort uint32 LocalAddr string LocalPort uint32 } err = ssh.Unmarshal(channel.ExtraData(), &directTCPIP) if err == nil { err = s.handleProxy(ctx, "tcp", channel, net.JoinHostPort(directTCPIP.DestAddr, strconv.FormatInt(int64(directTCPIP.DestPort), 10))) } case "direct-streamlocal@openssh.com": var directStreamLocal struct { DestAddr string LocalAddr string LocalPort uint32 } err = ssh.Unmarshal(channel.ExtraData(), &directStreamLocal) if err == nil { err = s.handleProxy(ctx, "unix", channel, directStreamLocal.DestAddr) } default: err = channel.Reject(ssh.UnknownChannelType, fmt.Sprintf("%v: %v", ssh.UnknownChannelType, channel.ChannelType())) } s.setError(err) }(channel) } }() } } func (s *StubSSHServer) handleProxy(ctx context.Context, network string, channel ssh.NewChannel, addr string) error { dialer := net.Dialer{Timeout: 30 * time.Second} upstream, err := dialer.DialContext(ctx, network, addr) if err != nil { return err } defer upstream.Close() conn, _, err := channel.Accept() if err != nil { return err } defer upstream.Close() recvCh := make(chan error, 1) sendCh := make(chan error, 1) go func() { recvCh <- copier(upstream, conn, "conn to upstream") }() go func() { err := copier(conn, upstream, "upstream to conn") if errors.Is(err, syscall.ENOTCONN) || errors.Is(err, io.EOF) { err = nil } sendCh <- err }() select { case err = <-recvCh: if err != nil { return err } err = <-sendCh case err = <-sendCh: } return err } //nolint:gocognit func copier(to io.Writer, from io.Reader, desc string) (err error) { defer func() { if t, ok := from.(interface{ CloseRead() error }); ok { if cerr := t.CloseRead(); cerr != nil && err == nil { err = fmt.Errorf("close reader (%s): %w", desc, cerr) } } if t, ok := to.(interface{ CloseWrite() error }); ok { if cerr := t.CloseWrite(); cerr != nil && err == nil { err = fmt.Errorf("close writer (%s): %w", desc, cerr) } } }() if _, err := io.Copy(to, from); err != nil { return fmt.Errorf("copy (%s): %w", desc, err) } return nil } //nolint:gocognit func (s *StubSSHServer) handleSession(ctx context.Context, channel ssh.NewChannel) error { conn, reqs, err := channel.Accept() if err != nil { return err } defer conn.Close() for req := range reqs { switch req.Type { case "exec": if req.WantReply { if err := req.Reply(true, nil); err != nil { return err } } var command struct { Value []byte } if err := ssh.Unmarshal(req.Payload, &command); err != nil { return fmt.Errorf("session unmarshal: %w", err) } if ctx.Err() != nil { return ctx.Err() } if len(s.Shell) == 0 { s.Shell = []string{"sh", "-c"} } args := append(s.Shell, string(command.Value)) //nolint:gocritic cmd := exec.CommandContext(ctx, args[0], args[1:]...) cmd.Dir = s.tempDir cmd.Stdout = conn cmd.Stderr = conn cmd.Stdin = conn runErr := runCmd(cmd) if ctx.Err() != nil { return ctx.Err() } var exitError *exec.ExitError code := 0 if errors.As(runErr, &exitError) { code = exitError.ExitCode() } var exit [4]byte binary.BigEndian.PutUint32(exit[:], uint32(code)) if err := conn.CloseWrite(); err != nil { return err } if _, err := conn.SendRequest("exit-status", false, exit[:]); err != nil { return err } return runErr default: return fmt.Errorf("unknown request type: %s", req.Type) } } return nil } func (s *StubSSHServer) Client() Client { return Client{ SshConfig: common.SshConfig{ User: s.User, Password: s.Password, Host: "127.0.0.1", Port: s.port, IdentityFile: s.privateKeyLocation, }, } }