pkg/sshd/tcpip.go (164 lines of code) (raw):
//go:build !EXTERNAL_SSH
// +build !EXTERNAL_SSH
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sshd
import (
"context"
"io"
"log"
"net"
"strconv"
"sync"
gossh "golang.org/x/crypto/ssh"
)
// Based on the gliderlabs/ssh implementation; this will probably be replaced
// with the more efficient implementation from wpgate.
// forwardedTCPChannelType is the SSH channel type the server opens toward the
// client for each connection accepted on a remote-forwarded listener.
const forwardedTCPChannelType = "forwarded-tcpip"
// localForwardChannelData is the type-specific data of a "direct-tcpip"
// channel open request, as specified in RFC 4254, Section 7.2.
// DirectTCPIPHandler decodes it with gossh.Unmarshal, so the field order
// below defines the wire layout and must not be changed.
type localForwardChannelData struct {
	// DestAddr and DestPort are the host and port the client asks the
	// server to connect to.
	DestAddr string
	DestPort uint32
	// OriginAddr and OriginPort describe the originator of the forwarded
	// connection on the client side; DirectTCPIPHandler does not read them.
	OriginAddr string
	OriginPort uint32
}
// DirectTCPIPHandler serves "direct-tcpip" channels (local port forwarding).
// Enable it by adding it to the server's ChannelHandlers under
// "direct-tcpip".
//
// It decodes the requested destination from the channel's extra data and
// dials it over TCP using ctx. If the data cannot be parsed or the dial
// fails, the channel is rejected with gossh.ConnectionFailed. Otherwise the
// channel is accepted, its in-band requests are discarded, and data is copied
// in both directions. As soon as either copy finishes, both the channel and
// the TCP connection are closed.
//
// NOTE(review): no policy check is applied, because the callback check below
// is commented out. Any client that can open this channel type can reach
// anything the server can dial. srv and conn are currently unused.
func DirectTCPIPHandler(ctx context.Context, srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel) {
	var fwd localForwardChannelData
	if err := gossh.Unmarshal(newChan.ExtraData(), &fwd); err != nil {
		newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
		return
	}

	//if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) {
	//	newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
	//	return
	//}

	target := net.JoinHostPort(fwd.DestAddr, strconv.FormatUint(uint64(fwd.DestPort), 10))
	upstream, err := (&net.Dialer{}).DialContext(ctx, "tcp", target)
	if err != nil {
		newChan.Reject(gossh.ConnectionFailed, err.Error())
		return
	}

	channel, requests, err := newChan.Accept()
	if err != nil {
		upstream.Close()
		return
	}
	go gossh.DiscardRequests(requests)

	// pipe copies one direction. Whichever direction ends first tears down
	// both ends, which also unblocks the other copy.
	pipe := func(dst io.Writer, src io.Reader) {
		defer channel.Close()
		defer upstream.Close()
		io.Copy(dst, src)
	}
	go pipe(channel, upstream)
	go pipe(upstream, channel)
}
// remoteForwardRequest is the payload of a "tcpip-forward" global request
// (RFC 4254, Section 7.1). It is decoded with gossh.Unmarshal, so the field
// order defines the wire layout.
type remoteForwardRequest struct {
	BindAddr string
	BindPort uint32
}

// remoteForwardSuccess is the reply payload for a successful "tcpip-forward"
// request. HandleSSHRequest fills BindPort with the port the listener
// actually bound, taken from the listener's address.
type remoteForwardSuccess struct {
	BindPort uint32
}

// remoteForwardCancelRequest is the payload of a "cancel-tcpip-forward"
// global request (RFC 4254, Section 7.1). It names the bind address and
// port of the forwarding to stop.
type remoteForwardCancelRequest struct {
	BindAddr string
	BindPort uint32
}

// remoteForwardChannelData is the type-specific data sent with each
// "forwarded-tcpip" channel open. HandleSSHRequest fills DestAddr with the
// requested bind address, DestPort with the bound port, and
// OriginAddr/OriginPort with the remote address of the accepted TCP
// connection. Field order defines the gossh.Marshal wire layout.
type remoteForwardChannelData struct {
	DestAddr   string
	DestPort   uint32
	OriginAddr string
	OriginPort uint32
}
// ForwardedTCPHandler implements remote (reverse) port forwarding. Enable it
// by creating a ForwardedTCPHandler and registering its HandleSSHRequest
// method in the server's RequestHandlers under both "tcpip-forward" and
// "cancel-tcpip-forward".
//
// The zero value is usable: HandleSSHRequest allocates the forwards map on
// first use. The embedded mutex guards the forwards map. Because it is
// embedded, Lock and Unlock are part of this type's exported method set, and
// the handler must not be copied after use.
type ForwardedTCPHandler struct {
	// forwards holds the active listeners, keyed by a "host:port" string
	// built with net.JoinHostPort.
	forwards map[string]net.Listener
	sync.Mutex
}
// HandleSSHRequest handles the "tcpip-forward" and "cancel-tcpip-forward"
// global requests (RFC 4254, Section 7.1). Any other request type is
// rejected with (false, nil).
//
// For "tcpip-forward" it listens on the requested address. For each accepted
// TCP connection it opens a "forwarded-tcpip" channel on conn and copies data
// in both directions. The listener is closed when ctx is done or when a
// matching "cancel-tcpip-forward" arrives. The success reply carries the port
// actually bound, which is how a client that asked for port 0 learns the
// allocated port.
//
// For "cancel-tcpip-forward" it closes the matching listener, if any, and
// always reports success.
//
// srv is currently unused; only the commented-out policy check in
// startForward would consult it.
func (h *ForwardedTCPHandler) HandleSSHRequest(ctx context.Context, srv *Server, req *gossh.Request, conn *gossh.ServerConn) (bool, []byte) {
	switch req.Type {
	case "tcpip-forward":
		return h.startForward(ctx, srv, req, conn)
	case "cancel-tcpip-forward":
		return h.cancelForward(req)
	default:
		return false, nil
	}
}

// startForward handles a "tcpip-forward" request; see HandleSSHRequest.
func (h *ForwardedTCPHandler) startForward(ctx context.Context, srv *Server, req *gossh.Request, conn *gossh.ServerConn) (bool, []byte) {
	var reqPayload remoteForwardRequest
	if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
		// TODO: log parse failure
		return false, []byte{}
	}
	//if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) {
	//	return false, []byte("port forwarding is disabled")
	//}

	// FormatUint rather than Itoa(int(...)): a uint32 may not fit in int on
	// 32-bit platforms.
	addr := net.JoinHostPort(reqPayload.BindAddr, strconv.FormatUint(uint64(reqPayload.BindPort), 10))
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		// TODO: log listen failure
		return false, []byte{}
	}
	_, destPort, err := splitForwardAddr(ln.Addr())
	if err != nil {
		// Don't leak the listener when its port can't be reported.
		ln.Close()
		return false, []byte{}
	}

	// Key the listener by the port actually bound, not the requested one.
	// Otherwise several port-0 requests would overwrite each other. A client
	// cancelling a port-0 forward is expected to name the allocated port.
	key := net.JoinHostPort(reqPayload.BindAddr, strconv.FormatUint(uint64(destPort), 10))
	h.Lock()
	if h.forwards == nil {
		h.forwards = make(map[string]net.Listener)
	}
	h.forwards[key] = ln
	h.Unlock()

	// done is closed when the accept loop exits, so the ctx watcher does not
	// outlive the listener.
	done := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			// Close this listener directly instead of looking it up by key.
			// By now the key might name a different listener.
			ln.Close()
		case <-done:
		}
	}()

	go func() {
		defer close(done)
		defer h.unregister(key, ln)
		for {
			c, err := ln.Accept()
			if err != nil {
				// The listener was closed (ctx done or cancel-tcpip-forward)
				// or Accept failed. TODO: log unexpected accept failures.
				return
			}
			go h.serveForwarded(conn, c, reqPayload.BindAddr, destPort)
		}
	}()

	return true, gossh.Marshal(&remoteForwardSuccess{destPort})
}

// cancelForward handles a "cancel-tcpip-forward" request; see
// HandleSSHRequest.
//
// NOTE(review): the forwards map is shared by every connection using this
// handler. Any client can therefore cancel a forward created by another
// connection.
func (h *ForwardedTCPHandler) cancelForward(req *gossh.Request) (bool, []byte) {
	var reqPayload remoteForwardCancelRequest
	if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
		// TODO: log parse failure
		return false, []byte{}
	}
	addr := net.JoinHostPort(reqPayload.BindAddr, strconv.FormatUint(uint64(reqPayload.BindPort), 10))
	h.Lock()
	ln, ok := h.forwards[addr]
	h.Unlock()
	if ok {
		// The accept loop sees the resulting error and unregisters the key.
		ln.Close()
	}
	return true, nil
}

// unregister removes key from the forwards map, but only if it still refers
// to ln, so a newer listener registered under the same key survives.
func (h *ForwardedTCPHandler) unregister(key string, ln net.Listener) {
	h.Lock()
	defer h.Unlock()
	if cur, ok := h.forwards[key]; ok && cur == ln {
		delete(h.forwards, key)
	}
}

// serveForwarded opens a "forwarded-tcpip" channel on conn for the accepted
// TCP connection c and copies data in both directions. bindAddr and bindPort
// are reported to the client as the connected address. As soon as either copy
// finishes, both the channel and c are closed.
func (h *ForwardedTCPHandler) serveForwarded(conn *gossh.ServerConn, c net.Conn, bindAddr string, bindPort uint32) {
	originAddr, originPort, err := splitForwardAddr(c.RemoteAddr())
	if err != nil {
		c.Close()
		return
	}
	payload := gossh.Marshal(&remoteForwardChannelData{
		DestAddr:   bindAddr,
		DestPort:   bindPort,
		OriginAddr: originAddr,
		OriginPort: originPort,
	})
	ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
	if err != nil {
		log.Println(err)
		c.Close()
		return
	}
	go gossh.DiscardRequests(reqs)
	go func() {
		defer ch.Close()
		defer c.Close()
		io.Copy(ch, c)
	}()
	go func() {
		defer ch.Close()
		defer c.Close()
		io.Copy(c, ch)
	}()
}

// splitForwardAddr splits a network address into its host and a port that
// fits the uint32 port fields of the SSH wire format.
func splitForwardAddr(a net.Addr) (string, uint32, error) {
	host, portStr, err := net.SplitHostPort(a.String())
	if err != nil {
		return "", 0, err
	}
	// ParseUint with bitSize 32 rejects anything that does not fit a uint32.
	// The previous `>= 2 ^ 32` check used XOR and actually tested `>= 34`.
	port, err := strconv.ParseUint(portStr, 10, 32)
	if err != nil {
		return "", 0, err
	}
	return host, uint32(port), nil
}