pkg/sshd/ssh_exec.go (438 lines of code) (raw):
//go:build !EXTERNAL_SSH
// +build !EXTERNAL_SSH
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sshd
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
"unsafe"
"github.com/creack/pty"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Inspired by the okteto code: https://raw.githubusercontent.com/okteto/remote/main/pkg/ssh/ssh.go
// Removed the dependency on the logger and integrated with ugate.
// Handles PTY/non-PTY shell sessions and sftp.
// gliderlabs: the current version doesn't work with certs. Its config() method requires a
// PublicKeyHandler, which doesn't have a reference to the conn (because gliderlabs decided to
// invent its own 'better' interface).
// In general, its interfaces and abstractions are too complex and not needed here.
var (
	// idleTimeout is a 60-second idle timeout. It is not referenced in this
	// file's visible code — presumably used by the server setup elsewhere (confirm).
	idleTimeout = 60 * time.Second
	// ErrEOF is the error when the terminal exits.
	// NOTE(review): not referenced in this file; confirm callers elsewhere.
	ErrEOF = errors.New("EOF")
)
// getExitStatusFromError maps the error returned by running a command to the
// exit status reported to the SSH client.
//
// It returns 0 for a nil error. If err is (or wraps) an *exec.ExitError, it
// returns the process exit status. Otherwise it returns 1. errors.As is used so
// that callers can add context with fmt.Errorf("...: %w", err) without losing
// the real exit code.
func getExitStatusFromError(err error) int {
	if err == nil {
		return 0
	}
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		return 1
	}
	waitStatus, ok := exitErr.Sys().(syscall.WaitStatus)
	if !ok {
		// Platform without a syscall.WaitStatus: fall back to success/failure.
		if exitErr.Success() {
			return 0
		}
		return 1
	}
	return waitStatus.ExitStatus()
}
// setWinsize resizes the terminal behind f to w columns by h rows using the
// TIOCSWINSZ ioctl. The ioctl result is ignored, so errors (for example f not
// being a terminal) are silently dropped.
func setWinsize(f *os.File, w, h int) {
	// Field order mirrors the kernel's struct winsize: rows, cols, xpixel, ypixel.
	size := struct{ rows, cols, xpix, ypix uint16 }{
		rows: uint16(h),
		cols: uint16(w),
	}
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}
// handlePTY runs cmd attached to a newly allocated pseudo-terminal and relays
// I/O between the terminal and the session s. If ptyReq.Term is non-empty,
// TERM is added to cmd's environment. Window sizes received on winCh are
// applied to the terminal until winCh is closed.
//
// It returns the error from starting or waiting for the command. After the
// command exits, the terminal output is given up to one second to drain
// before returning, whether or not the command succeeded.
func handlePTY(cmd *exec.Cmd, s *session, ptyReq Pty, winCh <-chan Window) error {
	if len(ptyReq.Term) > 0 {
		cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
	}
	f, err := pty.Start(cmd)
	if err != nil {
		log.Println("failed to start pty session", err)
		return err
	}
	// Release the pty master when done (it was previously leaked once per
	// session). Closing it also unblocks the output relay below if it is
	// still reading after the drain timeout.
	defer f.Close()

	go func() {
		for win := range winCh {
			setWinsize(f, win.Width, win.Height)
		}
	}()
	go func() {
		io.Copy(f, s) // stdin
	}()
	outDone := make(chan struct{})
	go func() {
		defer close(outDone)
		io.Copy(s, f) // stdout
	}()

	waitErr := cmd.Wait()
	if waitErr != nil {
		log.Println("pty command failed while waiting", waitErr)
	}

	// Drain output even when the command failed, so the client sees what was
	// printed before the error. Use a Timer that is stopped on return instead of
	// a Ticker that was never stopped.
	timer := time.NewTimer(1 * time.Second)
	defer timer.Stop()
	select {
	case <-outDone:
		log.Println("stdout finished")
	case <-timer.C:
		log.Println("stdout didn't finish after 1s")
	}
	return waitErr
}
// sendErrAndExit reports err to the client and ends the session. The error
// text, minus any leading "exec: ", is written to the session's stderr
// stream. The session then exits with the status derived from err by
// getExitStatusFromError. Failures at either step are only logged.
func sendErrAndExit(s *session, err error) {
	code := getExitStatusFromError(err)
	text := strings.TrimPrefix(err.Error(), "exec: ")

	if _, writeErr := s.Stderr().Write([]byte(text)); writeErr != nil {
		log.Println("failed to write error back to session", writeErr)
	}
	if exitErr := s.Exit(code); exitErr != nil {
		log.Println(exitErr, "pty session failed to exit")
	}
}
// handleNoTTY runs cmd without a terminal. The session's input is piped to the
// command's stdin, the command's stdout goes to the session, and its stderr goes
// to the session's stderr stream.
//
// Both output streams are fully copied before cmd.Wait is called, because
// Wait closes the pipes once the process exits. The function returns the error
// from creating the pipes, starting the command, or waiting for it.
func handleNoTTY(cmd *exec.Cmd, s *session) error {
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		log.Println(err, "couldn't get StdoutPipe")
		return err
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		log.Println(err, "couldn't get StderrPipe")
		return err
	}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		log.Println(err, "couldn't get StdinPipe")
		return err
	}
	if err = cmd.Start(); err != nil {
		// Printf, not Println: the original passed a %s directive to Println,
		// which go vet flags and which never substituted the command.
		log.Printf("%v couldn't start command '%s'", err, cmd.String())
		return err
	}

	go func() {
		// Closing stdin tells the command that client input has ended.
		defer stdin.Close()
		if _, err := io.Copy(stdin, s); err != nil {
			log.Println(err, "failed to write session to stdin.")
		}
	}()

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		if _, err := io.Copy(s, stdout); err != nil {
			log.Println(err, "failed to write stdout to session.")
		}
	}()
	go func() {
		defer wg.Done()
		if _, err := io.Copy(s.Stderr(), stderr); err != nil {
			log.Println(err, "failed to write stderr to session.")
		}
	}()
	wg.Wait()

	if err := cmd.Wait(); err != nil {
		log.Println(err, "command failed while waiting")
		return err
	}
	return nil
}
// Signal is a signal name carried by an SSH "signal" channel request. The
// name is stored verbatim from the request payload by handleRequests.
type Signal string
// Window represents the size of a PTY window.
// Width and Height are taken from the first two uint32 fields of a "pty-req"
// (after the TERM string) or "window-change" request payload.
type Window struct {
	Width  int
	Height int
}
// Pty represents a PTY request and configuration.
type Pty struct {
	// Term is the TERM value requested by the client ("" if none was sent).
	Term string
	// Window is the initial size, updated by window-change requests.
	Window Window
	// HELP WANTED: terminal modes!
}
// session is the server-side state for one SSH "session" channel. It embeds
// the ssh.Channel, so reads and writes on the session go to the channel
// (Write additionally rewrites line endings when a PTY was requested).
type session struct {
	// Guards exited, sigCh, sigBuf and breakCh (locked by Exit, Signals, Break
	// and the "signal"/"break" request handlers).
	sync.Mutex
	ssh.Channel
	// conn is the server connection; not referenced in this file's visible code.
	conn *ssh.ServerConn
	//handler Handler
	//subsystemHandlers map[string]SubsystemHandler
	// srv is the owning server; provides connectionHandler and buildCmd.
	srv *Server
	// handled is set once a shell/exec/subsystem request is accepted; later
	// shell/exec/subsystem/env/pty-req requests are then rejected.
	handled bool
	// exited is set by the first Exit call; subsequent calls return an error.
	exited bool
	// pty is non-nil once a pty-req has been accepted.
	pty *Pty
	// winch carries window sizes to the PTY relay (created with capacity 1 by
	// the pty-req handler).
	winch chan Window
	// env holds "KEY=VALUE" entries from env requests; buildCmd appends them
	// to the command environment.
	env []string
	//ptyCb PtyCallback
	// rawCmd is the command string from the shell/exec request payload
	// (normally empty for "shell").
	rawCmd string
	// subsystem is the requested subsystem name; only "sftp" is served.
	subsystem string
	// sigCh receives signals once registered via Signals; until then up to
	// maxSigBufSize signals are kept in sigBuf.
	sigCh  chan<- Signal
	sigBuf []Signal
	// breakCh receives "break" requests once registered via Break.
	breakCh chan<- bool
}
// Write sends p to the client on the channel's stdout stream.
//
// When a PTY was requested, each "\n" is sent as "\r\n" (an existing "\r\n"
// is not doubled). This is a hardcoded shortcut because terminal modes are not
// supported. The returned count is capped at len(p) so callers never see more
// bytes written than they supplied.
func (sess *session) Write(p []byte) (n int, err error) {
	if sess.pty == nil {
		return sess.Channel.Write(p)
	}
	requested := len(p)
	crlf := bytes.Replace(p, []byte("\n"), []byte("\r\n"), -1)
	crlf = bytes.Replace(crlf, []byte("\r\r\n"), []byte("\r\n"), -1)
	n, err = sess.Channel.Write(crlf)
	if n > requested {
		n = requested
	}
	return n, err
}
// Exit sends an "exit-status" request carrying code and then closes the
// channel. Only the first call has any effect; later calls return an error.
//
// The channel is closed even when sending the exit status fails. Previously
// the channel was left open in that case. The send error takes precedence in
// the returned error.
func (sess *session) Exit(code int) error {
	sess.Lock()
	defer sess.Unlock()
	if sess.exited {
		return errors.New("Session.Exit called multiple times")
	}
	sess.exited = true

	status := struct{ Status uint32 }{uint32(code)}
	_, sendErr := sess.SendRequest("exit-status", false, ssh.Marshal(&status))
	closeErr := sess.Close()
	if sendErr != nil {
		return sendErr
	}
	return closeErr
}
// Pty reports the accepted PTY request, if any. It returns the requested
// Pty (zero value when none), the window-change channel, and whether a PTY
// was requested.
func (sess *session) Pty() (Pty, <-chan Window, bool) {
	if sess.pty == nil {
		return Pty{}, sess.winch, false
	}
	return *sess.pty, sess.winch, true
}
// Signals registers c to receive signals sent by the client in "signal"
// requests. Signals buffered before registration are delivered to c
// asynchronously, and the buffer is then cleared. Passing nil unregisters the
// channel and keeps any buffered signals for a later registration.
//
// The delivery goroutine works on local copies of the buffer and channel.
// The previous version read sess.sigCh and sess.sigBuf without holding the
// lock. It also never cleared the buffer, so re-registering re-delivered old
// signals, and with a nil c it blocked forever.
func (sess *session) Signals(c chan<- Signal) {
	sess.Lock()
	defer sess.Unlock()
	sess.sigCh = c
	if c == nil || len(sess.sigBuf) == 0 {
		return
	}
	pending := sess.sigBuf
	sess.sigBuf = nil
	go func() {
		for _, sig := range pending {
			c <- sig
		}
	}()
}
// Break registers c to receive a value for each "break" request from the
// client. While no channel is registered, break requests are rejected.
func (sess *session) Break(c chan<- bool) {
	sess.Lock()
	sess.breakCh = c
	sess.Unlock()
}
// maxSigBufSize caps how many "signal" requests are buffered before a
// receiver is registered with Signals; further signals are dropped.
const maxSigBufSize = 128

// handleRequests processes the channel requests for this session until reqs
// is closed.
//
// The first accepted "shell"/"exec" request starts connectionHandler, and the
// first "subsystem" request for "sftp" starts sftpHandler. Each runs in its
// own goroutine. "env" and "pty-req" are only accepted before that point.
// Unknown request types are rejected.
func (sess *session) handleRequests(reqs <-chan *ssh.Request) {
	// winch is created by the (single) accepted pty-req. Close it when the
	// request stream ends so the window-size relay in handlePTY exits.
	defer func() {
		if sess.winch != nil {
			close(sess.winch)
		}
	}()
	for req := range reqs {
		switch req.Type {
		case "shell", "exec":
			if sess.handled {
				req.Reply(false, nil)
				continue
			}
			var payload = struct{ Value string }{}
			// "shell" carries no payload, so a decode failure is only fatal for
			// "exec", which must name the command. Previously a malformed exec
			// silently ran an empty command, or the configured shell.
			if err := ssh.Unmarshal(req.Payload, &payload); err != nil && req.Type == "exec" {
				req.Reply(false, nil)
				continue
			}
			sess.rawCmd = payload.Value
			// TODO: session policy callback hook (sessReqCb) not implemented.
			sess.handled = true
			req.Reply(true, nil)
			go func() {
				sess.srv.connectionHandler(sess)
				sess.Exit(0)
			}()
		case "subsystem":
			if sess.handled {
				req.Reply(false, nil)
				continue
			}
			var payload = struct{ Value string }{}
			ssh.Unmarshal(req.Payload, &payload) // malformed payload => "" => rejected below
			sess.subsystem = payload.Value
			if payload.Value != "sftp" {
				req.Reply(false, nil)
				continue
			}
			sess.handled = true
			req.Reply(true, nil)
			go func() {
				sftpHandler(sess)
				sess.Exit(0)
			}()
		case "env":
			if sess.handled {
				req.Reply(false, nil)
				continue
			}
			var kv struct{ Key, Value string }
			if err := ssh.Unmarshal(req.Payload, &kv); err != nil {
				// Don't inject a bogus "=" entry into the command environment.
				req.Reply(false, nil)
				continue
			}
			sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value))
			req.Reply(true, nil)
		case "signal":
			var payload struct{ Signal string }
			ssh.Unmarshal(req.Payload, &payload)
			sess.Lock()
			if sess.sigCh != nil {
				sess.sigCh <- Signal(payload.Signal)
			} else if len(sess.sigBuf) < maxSigBufSize {
				sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal))
			}
			sess.Unlock()
		case "pty-req":
			if sess.handled || sess.pty != nil {
				req.Reply(false, nil)
				continue
			}
			ptyReq, ok := parsePtyRequest(req.Payload)
			if !ok {
				req.Reply(false, nil)
				continue
			}
			sess.pty = &ptyReq
			sess.winch = make(chan Window, 1)
			// Seed with the initial size; handlePTY applies it once started.
			sess.winch <- ptyReq.Window
			req.Reply(true, nil)
		case "window-change":
			if sess.pty == nil {
				req.Reply(false, nil)
				continue
			}
			win, ok := parseWinchRequest(req.Payload)
			if ok {
				sess.pty.Window = win
				// Never block the request loop. Nothing drains winch until the
				// command has started, and its single slot may already hold a
				// size, so replace a stale size instead of waiting.
				// This goroutine is the only sender, so the final send cannot block.
				select {
				case sess.winch <- win:
				default:
					select {
					case <-sess.winch:
					default:
					}
					sess.winch <- win
				}
			}
			req.Reply(ok, nil)
		// TODO: agent forwarding ("auth-agent-req@openssh.com") not supported.
		case "break":
			ok := false
			sess.Lock()
			if sess.breakCh != nil {
				sess.breakCh <- true
				ok = true
			}
			req.Reply(ok, nil)
			sess.Unlock()
		default:
			// TODO: debug log
			req.Reply(false, nil)
		}
	}
}
// connectionHandler runs the command for an accepted shell/exec request. It
// builds the command with buildCmd, runs it with or without a PTY depending
// on whether a pty-req was accepted, and finally exits the session. The exit
// status is 0 on success; on failure the error is reported through
// sendErrAndExit. The session channel is always closed on return.
//
// TODO: SSH agent forwarding (SSH_AUTH_SOCK) is not implemented.
func (srv *Server) connectionHandler(s *session) {
	defer func() {
		s.Close()
		log.Println("session closed")
	}()
	log.Printf("starting ssh session with command '%+v'", s.rawCmd)
	cmd := srv.buildCmd(s)

	var runErr error
	if ptyReq, winCh, isPty := s.Pty(); isPty {
		log.Println("handling PTY session")
		runErr = handlePTY(cmd, s, ptyReq, winCh)
	} else {
		log.Println("handling non PTY session")
		runErr = handleNoTTY(cmd, s)
	}

	if runErr != nil {
		sendErrAndExit(s, runErr)
		return
	}
	s.Exit(0)
}
// sftpHandler serves the SFTP protocol over the session channel until the
// client disconnects or an error occurs. Debug output is discarded.
//
// The sftp server, and with it the session channel, is closed however Serve
// ends. Previously it was closed only when Serve returned io.EOF. Newer
// pkg/sftp versions return nil on a clean exit, so the server was usually
// leaked.
func sftpHandler(sess *session) {
	debugStream := ioutil.Discard
	serverOptions := []sftp.ServerOption{
		sftp.WithDebug(debugStream),
	}
	server, err := sftp.NewServer(sess, serverOptions...)
	if err != nil {
		log.Printf("sftp server init error: %s\n", err)
		return
	}
	defer server.Close()

	if err := server.Serve(); err == io.EOF {
		log.Println("sftp client exited session.")
	} else if err != nil {
		log.Println("sftp server completed with error:", err)
	}
}
// buildCmd constructs the command for the session's shell/exec request.
//
// With srv.Shell set, an empty rawCmd starts srv.Shell directly (interactive
// shell). Otherwise srv.Shell is run with "-c rawCmd". Without a configured
// shell, rawCmd is split on whitespace and executed directly. No quoting is
// supported. The environment is the server's own environment followed by the
// client's env requests; later entries override earlier duplicates.
func (srv Server) buildCmd(s *session) *exec.Cmd {
	var cmd *exec.Cmd
	switch {
	case srv.Shell != "" && s.rawCmd == "":
		cmd = exec.Command(srv.Shell)
	case srv.Shell != "":
		cmd = exec.Command(srv.Shell, "-c", s.rawCmd)
	default:
		// strings.Fields, not Split(" "): repeated spaces previously produced
		// empty argv entries (e.g. "ls  -l" ran ls with an "" argument).
		args := strings.Fields(s.rawCmd)
		if len(args) == 0 {
			// As before, an empty command yields a Cmd whose Start fails
			// ("no command"), which is reported back to the client.
			cmd = exec.Command("")
		} else {
			cmd = exec.Command(args[0], args[1:]...)
		}
	}
	cmd.Env = append(os.Environ(), s.env...)
	return cmd
}
// parsePtyRequest decodes the leading fields of a "pty-req" payload: the TERM
// string followed by two uint32 values used as width and height. Any
// remaining fields are ignored. On a short or malformed payload it returns the
// zero Pty and false.
func parsePtyRequest(s []byte) (pty Pty, ok bool) {
	var term string
	var width, height uint32
	if term, s, ok = parseString(s); !ok {
		return
	}
	if width, s, ok = parseUint32(s); !ok {
		return
	}
	if height, _, ok = parseUint32(s); !ok {
		return
	}
	pty = Pty{
		Term:   term,
		Window: Window{Width: int(width), Height: int(height)},
	}
	return
}
// parseWinchRequest decodes the width and height (two leading uint32 values)
// of a "window-change" payload. A truncated payload or a zero dimension is
// rejected with the zero Window and false.
func parseWinchRequest(s []byte) (win Window, ok bool) {
	var width, height uint32
	if width, s, ok = parseUint32(s); !ok || width == 0 {
		return Window{}, false
	}
	if height, _, ok = parseUint32(s); !ok || height == 0 {
		return Window{}, false
	}
	return Window{Width: int(width), Height: int(height)}, true
}
// parseString decodes an SSH wire-format string: a big-endian uint32 length
// followed by that many bytes. It returns the string, the bytes after it, and
// true. If in is too short it returns "", nil, false.
//
// The length comes from the peer. It is compared in 64-bit arithmetic because
// the old check `len(in) < 4+length` overflowed uint32 for lengths near 2^32.
// The check then passed and the slice expression panicked, letting a client
// crash the server with a crafted pty-req.
func parseString(in []byte) (out string, rest []byte, ok bool) {
	if len(in) < 4 {
		return
	}
	length := binary.BigEndian.Uint32(in)
	body := in[4:]
	if uint64(length) > uint64(len(body)) {
		return
	}
	out = string(body[:length])
	rest = body[length:]
	ok = true
	return
}
// parseUint32 decodes a big-endian uint32 from the front of in. It returns the
// value, the remaining bytes, and true, or 0, nil, false when fewer than four
// bytes are available.
func parseUint32(in []byte) (uint32, []byte, bool) {
	if len(in) >= 4 {
		return binary.BigEndian.Uint32(in[:4]), in[4:], true
	}
	return 0, nil, false
}