pkg/sshd/tcpip.go (164 lines of code) (raw):

//go:build !EXTERNAL_SSH
// +build !EXTERNAL_SSH

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sshd

import (
	"context"
	"io"
	"log"
	"net"
	"strconv"
	"sync"

	gossh "golang.org/x/crypto/ssh"
)

// TCP/IP port forwarding (RFC 4254, Section 7) for the embedded SSH server.
//
// based on gliderlabs - will probably be replaced with the more
// efficient impl from wpgate.

const (
	// forwardedTCPChannelType is the channel type the server opens back to
	// the client for each connection accepted on a remote-forward listener.
	forwardedTCPChannelType = "forwarded-tcpip"
)

// localForwardChannelData is the "direct-tcpip" channel-open payload as
// specified in RFC 4254, Section 7.2. It is filled in by gossh.Unmarshal,
// so the field order mirrors the wire format and must not be changed.
type localForwardChannelData struct {
	DestAddr   string // host the client asks the server to connect to
	DestPort   uint32 // port on DestAddr
	OriginAddr string // originator address; not used by DirectTCPIPHandler
	OriginPort uint32 // originator port; not used by DirectTCPIPHandler
}

// DirectTCPIPHandler can be enabled by adding it to the server's
// ChannelHandlers under direct-tcpip.
//
// It decodes the channel's extra data, dials DestAddr:DestPort over TCP
// (bounded by ctx), and rejects the channel if either step fails. Otherwise
// it accepts the channel and copies bytes in both directions. When either
// direction finishes, both ends are closed.
//
// NOTE(review): the local-port-forwarding permission callback is commented
// out, so any destination the client names is dialed.
func DirectTCPIPHandler(ctx context.Context, srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel) { d := localForwardChannelData{} if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) return } //if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { // newChan.Reject(gossh.Prohibited, "port forwarding is disabled") // return //} dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) var dialer net.Dialer dconn, err := dialer.DialContext(ctx, "tcp", dest) if err != nil { newChan.Reject(gossh.ConnectionFailed, err.Error()) return } ch, reqs, err := newChan.Accept() if err != nil { dconn.Close() return } go gossh.DiscardRequests(reqs) go func() { defer ch.Close() defer dconn.Close() io.Copy(ch, dconn) }() go func() { defer ch.Close() defer dconn.Close() io.Copy(dconn, ch) }() } type remoteForwardRequest struct { BindAddr string BindPort uint32 } type remoteForwardSuccess struct { BindPort uint32 } type remoteForwardCancelRequest struct { BindAddr string BindPort uint32 } type remoteForwardChannelData struct { DestAddr string DestPort uint32 OriginAddr string OriginPort uint32 } // ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and // adding the HandleSSHRequest callback to the server's RequestHandlers under // tcpip-forward and cancel-tcpip-forward. 
type ForwardedTCPHandler struct { forwards map[string]net.Listener sync.Mutex } func (h *ForwardedTCPHandler) HandleSSHRequest(ctx context.Context, srv *Server, req *gossh.Request, conn *gossh.ServerConn) (bool, []byte) { h.Lock() if h.forwards == nil { h.forwards = make(map[string]net.Listener) } h.Unlock() switch req.Type { case "tcpip-forward": var reqPayload remoteForwardRequest if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { // TODO: log parse failure return false, []byte{} } //if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { // return false, []byte("port forwarding is disabled") //} addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) ln, err := net.Listen("tcp", addr) if err != nil { // TODO: log listen failure return false, []byte{} } _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, err := strconv.Atoi(destPortStr) if err != nil || destPort >= 2 ^ 32 { return false, []byte{} } h.Lock() h.forwards[addr] = ln h.Unlock() go func() { <-ctx.Done() h.Lock() ln, ok := h.forwards[addr] h.Unlock() if ok { ln.Close() } }() go func() { for { c, err := ln.Accept() if err != nil { // TODO: log accept failure break } originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) originPort, err := strconv.Atoi(orignPortStr) if err != nil || originPort >= 2 ^ 32 { c.Close() continue } payload := gossh.Marshal(&remoteForwardChannelData{ DestAddr: reqPayload.BindAddr, DestPort: uint32(destPort), OriginAddr: originAddr, OriginPort: uint32(originPort), }) go func() { ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) if err != nil { // TODO: log failure to open channel log.Println(err) c.Close() return } go gossh.DiscardRequests(reqs) go func() { defer ch.Close() defer c.Close() io.Copy(ch, c) }() go func() { defer ch.Close() defer c.Close() io.Copy(c, ch) }() }() } h.Lock() delete(h.forwards, 
addr) h.Unlock() }() return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) case "cancel-tcpip-forward": var reqPayload remoteForwardCancelRequest if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { // TODO: log parse failure return false, []byte{} } addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) h.Lock() ln, ok := h.forwards[addr] h.Unlock() if ok { ln.Close() } return true, nil default: return false, nil } }