pkg/hbone/io.go (272 lines of code) (raw):
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package hbone
import (
"context"
"crypto/tls"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
var (
	// bufSize is the capacity of the buffers handed out by bufferPoolCopy.
	// It is read each time the pool allocates, so changing it only affects
	// buffers allocated afterwards.
	// TODO: benchmark different sizes.
	bufSize = 32 * 1024

	// Debug enables verbose logging, e.g. of every read and write performed
	// by Stream.CopyBuffered.
	Debug = false

	// bufferPoolCopy hands out zero-length byte slices with capacity bufSize
	// (32KiB by default, the same buffer size io.Copy uses); users reslice up
	// to cap. Experimental use showed reads of at most ~20KiB with Firefox.
	bufferPoolCopy = sync.Pool{New: func() interface{} {
		return make([]byte, 0, bufSize)
	}}
)
// CloseWriter is implemented by writers that can half-close: CloseWrite sends
// a FIN (end of stream) to the peer while leaving the input side open. It is
// one of the interfaces probed on an output stream; some writers can only send
// the FIN when Close is called.
type CloseWriter interface {
	// CloseWrite shuts down the writing side only.
	CloseWrite() error
}
// streamIDs is the counter behind the numeric IDs given to Streams that have
// no ID; it is incremented with atomic.AddInt64.
var streamIDs int64
// Stream describes one direction of a proxied connection: CopyBuffered copies
// bytes from Src to Dst and sets the result fields below as it goes.
// NOTE(review): callers only observe those fields if CopyBuffered updates the
// Stream through a pointer receiver.
type Stream struct {
	// Written counts the bytes written to Dst.
	Written int64
	// Err is the error that ended the copy; nil when Src reached EOF.
	Err error
	// InError is set when the copy ended because reading Src failed.
	InError bool

	// Src is the reader data is copied from.
	Src io.Reader
	// Dst is the writer data is copied to.
	Dst io.Writer

	// ID labels debug log lines; a numeric ID is assigned when it is empty.
	ID string
}
// proxy pipes data in both directions until both directions have finished:
// cin -> sout runs on a separate goroutine, while sin -> cout runs on the
// calling goroutine. Both copies are started with close=true, so closeWriter
// is applied to each destination when its source ends.
//
// The cin -> sout error is returned if there is one, otherwise the
// sin -> cout error. NOTE(review): the errors are read back from the Stream
// values, so they are only visible if CopyBuffered records results through a
// pointer receiver.
//
// NOTE(review): ctx is currently unused; cancelling it does not stop the copy.
func proxy(ctx context.Context, cin io.Reader, cout io.WriteCloser, sin io.Reader, sout io.WriteCloser) error {
	outbound := Stream{ID: "client-o", Src: cin, Dst: sout}
	inbound := Stream{ID: "client-i", Src: sin, Dst: cout}

	finished := make(chan int)
	go outbound.CopyBuffered(finished, true)
	inbound.CopyBuffered(nil, true)
	<-finished

	if err := outbound.Err; err != nil {
		return err
	}
	return inbound.Err
}
// CopyBuffered copies s.Src to s.Dst through a pooled intermediary buffer and
// records the outcome on s.
//
// It blocks until Src returns EOF or an error, or until a write to Dst fails.
// When it returns:
//   - s.Written is the number of bytes successfully written to Dst;
//   - s.Err is the error that ended the copy, or nil on EOF (a read error
//     whose text contains "NetworkIdleTimeout" is treated as EOF);
//   - s.InError is true if the copy ended because reading Src failed.
//
// If s.ID is empty, a numeric ID is assigned; it is only used in debug logs.
//
// If ch is non-nil, a single 0 is sent on it once the copy is done, so
// CopyBuffered can run in a goroutine and be waited on; read the results from
// s after receiving. If close is true and the copy ended on EOF or a read
// error, closeWriter is called on Dst to propagate the end of stream. A write
// error returns without closing Dst.
//
// The receiver is a pointer so that the results above are visible to the
// caller.
//
// The loop is hand-written rather than using io.CopyBuffer on purpose:
// io.Copy would hand off to WriterTo/ReaderFrom, bypassing the byte count and
// the per-write Flush below.
func (s *Stream) CopyBuffered(ch chan int, close bool) {
	pooled := bufferPoolCopy.Get().([]byte)
	defer bufferPoolCopy.Put(pooled)
	buf := pooled[:cap(pooled):cap(pooled)]

	if ch != nil {
		defer func() { ch <- 0 }()
	}
	if s.ID == "" {
		s.ID = strconv.Itoa(int(atomic.AddInt64(&streamIDs, 1)))
	}
	if Debug {
		log.Println(s.ID, "startCopy()")
	}

	// Both assertions are loop-invariant, so do them once.
	srcConn, _ := s.Src.(net.Conn)
	flusher, _ := s.Dst.(http.Flusher)

	for {
		if srcConn != nil {
			// Idle timeout: each read must complete within 15 minutes.
			srcConn.SetReadDeadline(time.Now().Add(15 * time.Minute))
		}
		nr, er := s.Src.Read(buf)
		if Debug {
			log.Println(s.ID, "read()", nr, er)
		}
		if nr > 0 { // io.Reader contract: handle the data before the read error
			nw, ew := s.Dst.Write(buf[:nr])
			if Debug {
				log.Println(s.ID, "write()", nw, ew)
			}
			if nw < 0 || nw > nr {
				nw = 0 // impossible count from a misbehaving writer; count nothing
			}
			s.Written += int64(nw)
			if flusher != nil {
				flusher.Flush()
			}
			// A real write error takes precedence over the generic
			// short-write error, as in io.Copy.
			if ew == nil && nw != nr {
				ew = io.ErrShortWrite
				if Debug {
					log.Println(s.ID, "write error - short write", ew)
				}
			}
			if ew != nil {
				s.Err = ew
				return
			}
		}
		if er != nil {
			// An idle-timeout close is treated as a graceful end of stream.
			if strings.Contains(er.Error(), "NetworkIdleTimeout") {
				er = io.EOF
			}
			if er == io.EOF {
				if Debug {
					log.Println(s.ID, "done()")
				}
			} else {
				s.Err = er
				s.InError = true
				if Debug {
					log.Println(s.ID, "readError()", s.Err)
				}
			}
			if close {
				// The read side is finished; propagate that to Dst.
				if err := closeWriter(s.Dst); err != nil && Debug {
					log.Println(s.ID, "closeWriter()", err)
				}
			}
			return
		}
	}
}
// closeWriter signals end of stream on dst, preferring a half-close. It tries,
// in order:
//   - CloseWriter.CloseWrite, which sends a FIN but keeps the input open;
//   - io.Closer.Close, which closes dst entirely;
//   - for an http.ResponseWriter, setting the "X-Close" header and flushing
//     if the writer supports http.Flusher (see the comment below).
//
// Any other writer is logged and left untouched. The returned error comes from
// CloseWrite or Close; the ResponseWriter and fallback paths return nil.
func closeWriter(dst io.Writer) error {
	if cw, ok := dst.(CloseWriter); ok {
		return cw.CloseWrite()
	}
	if c, ok := dst.(io.Closer); ok {
		return c.Close()
	}
	if rw, ok := dst.(http.ResponseWriter); ok {
		// Server side HTTP stream. For client side, FIN can be sent by closing the pipe (or
		// request body). For server, the FIN will be sent when the handler returns - but
		// this only happens after the request is completed and the body has been read. If the
		// server wants to send FIN first - while still reading the body - we are in trouble.
		// That means HTTP2 TCP servers provide no way to send a FIN from server, without
		// having the request fully read.
		// This works for H2 with the current library - but very tricky, if not set as trailer.
		rw.Header().Set("X-Close", "0")
		// Not every ResponseWriter implements Flusher; an unchecked assertion
		// would panic here.
		if f, ok := rw.(http.Flusher); ok {
			f.Flush()
		}
		return nil
	}
	log.Println("Server out not Closer nor CloseWriter nor ResponseWriter", dst)
	return nil
}
// HTTPConn adapts an HTTP request/response pair to net.Conn. Read is served by
// r, Write goes to w and is flushed after every call, and the addresses are
// those of the underlying accepted connection. Deadlines are not supported.
type HTTPConn struct {
	r            io.Reader
	w            io.Writer
	acceptedConn net.Conn
}

// *HTTPConn is meant to be usable wherever a net.Conn is expected.
var _ net.Conn = (*HTTPConn)(nil)

// Read reads from the wrapped reader.
func (hc *HTTPConn) Read(b []byte) (int, error) {
	return hc.r.Read(b)
}

// Write writes b to the wrapped writer, which may be an http.ResponseWriter,
// and then flushes it when it implements http.Flusher, because normal HTTP
// output is buffered. The Write result is returned unchanged; Flush runs even
// if Write failed.
func (hc *HTTPConn) Write(b []byte) (int, error) {
	n, err := hc.w.Write(b)
	if flusher, ok := hc.w.(http.Flusher); ok {
		flusher.Flush()
	}
	return n, err
}

// Close half-closes the connection by calling CloseWrite on the writer. A
// writer without CloseWrite is only logged and nil is returned; the reader is
// never closed.
func (hc *HTTPConn) Close() error {
	// TODO: close write
	cw, ok := hc.w.(CloseWriter)
	if !ok {
		log.Println("Unexpected writer not implement CloseWriter")
		return nil
	}
	return cw.CloseWrite()
}

// LocalAddr returns the local address of the accepted connection.
// NOTE(review): panics if acceptedConn is nil.
func (hc *HTTPConn) LocalAddr() net.Addr {
	return hc.acceptedConn.LocalAddr()
}

// RemoteAddr returns the remote address of the accepted connection.
// NOTE(review): panics if acceptedConn is nil.
func (hc *HTTPConn) RemoteAddr() net.Addr {
	return hc.acceptedConn.RemoteAddr()
}

// SetDeadline is a no-op that always returns nil.
func (hc *HTTPConn) SetDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op that always returns nil.
func (hc *HTTPConn) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline is a no-op that always returns nil.
func (hc *HTTPConn) SetWriteDeadline(t time.Time) error { return nil }
// tlsHandshakeTimeoutError is returned by HandshakeTimeout when the TLS
// handshake does not finish in time. It reports itself as both a timeout and
// a temporary error.
type tlsHandshakeTimeoutError struct{}

func (tlsHandshakeTimeoutError) Timeout() bool   { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string   { return "net/http: TLS handshake timeout" }

// HandshakeTimeout runs tlsConn.Handshake but gives up after d, so a stalled
// peer cannot hang the connection. A zero d means 3 seconds.
//
// On success it returns nil. On failure it closes plainConn if non-nil,
// otherwise tlsConn, and returns the error. The failure is either the
// handshake error or a tlsHandshakeTimeoutError, whichever arrives first. On
// timeout, closing the connection is expected to make the still-running
// Handshake fail so its goroutine can exit.
func HandshakeTimeout(tlsConn *tls.Conn, d time.Duration, plainConn net.Conn) error {
	if d == 0 {
		d = 3 * time.Second
	}
	// Room for both possible senders (timer and handshake goroutine), so the
	// one that loses the race never blocks.
	results := make(chan error, 2)
	timer := time.AfterFunc(d, func() {
		results <- tlsHandshakeTimeoutError{}
	})
	go func() {
		err := tlsConn.Handshake()
		timer.Stop()
		results <- err
	}()

	err := <-results
	if err == nil {
		return nil
	}
	if plainConn != nil {
		plainConn.Close()
	} else {
		tlsConn.Close()
	}
	return err
}
// ListenAndServeTCP listens on the TCP address addr and, on a background
// goroutine, hands every accepted connection to f (see ServeListener). The
// listener is returned so the caller can close it to stop serving. An error
// is returned only if listening fails; ServeListener's final error is
// discarded.
func ListenAndServeTCP(addr string, f func(conn net.Conn)) (net.Listener, error) {
	l, err := net.Listen("tcp", addr)
	if err == nil {
		go ServeListener(l, f)
		return l, nil
	}
	return nil, err
}

// ServeListener accepts connections from l and runs f on a new goroutine for
// each one. Accept errors that report Temporary() == true are retried after a
// 100ms pause. The first other Accept error is logged and returned.
func ServeListener(l net.Listener, f func(conn net.Conn)) error {
	type temporary interface {
		Temporary() bool
	}
	for {
		conn, err := l.Accept()
		if err == nil {
			go f(conn)
			continue
		}
		if te, ok := err.(temporary); ok && te.Temporary() {
			time.Sleep(100 * time.Millisecond)
			continue
		}
		// TODO: callback to notify. This may happen if interface restarts, etc.
		log.Println("Accepted done ", l)
		return err
	}
}
// BufferReader wraps a Reader with a pooled buffer so the beginning of a stream
// can be inspected (Fill) before it is consumed (Read).
//
// Read first returns data buffered by Fill. Once the buffer is drained, it
// reads directly from Reader.
type BufferReader struct {
	buf        []byte // pooled storage; Fill reads into buf[rend:cap(buf)]
	roff, rend int    // buffered bytes not yet returned by Read are buf[roff:rend]
	Reader     io.Reader
}

// NewBufferReader returns a BufferReader over in, backed by a buffer from the
// shared copy pool. Close returns the buffer to the pool.
func NewBufferReader(in io.Reader) *BufferReader {
	return &BufferReader{buf: bufferPoolCopy.Get().([]byte), Reader: in}
}

// Fill reads from Reader until at least i bytes are buffered. It returns all
// buffered bytes (buf[0:rend], possibly more than i).
//
// If Reader fails first, the bytes buffered so far are returned with the
// error. If the buffer fills up before i bytes arrive, io.ErrShortBuffer is
// returned rather than spinning forever on zero-length reads.
func (s *BufferReader) Fill(i int) ([]byte, error) {
	for s.rend < i {
		if s.rend >= cap(s.buf) {
			return s.buf[:s.rend], io.ErrShortBuffer
		}
		n, err := s.Reader.Read(s.buf[s.rend:cap(s.buf)])
		s.rend += n
		if err != nil {
			return s.buf[:s.rend], err
		}
	}
	return s.buf[:s.rend], nil
}

// Read will first return the buffered data, then read.
// For SNI routing we don't actually need this - in is a TcpConn and
// we'll use in.ReadFrom to take advantage of splice.
func (s *BufferReader) Read(d []byte) (int, error) {
	if s.rend > 0 {
		n := copy(d, s.buf[s.roff:s.rend])
		s.roff += n
		if s.roff == s.rend {
			// Drained: reset both offsets so a later Fill starts at 0 and the
			// next Read does not slice with a stale roff.
			s.roff, s.rend = 0, 0
		}
		return n, nil
	}
	return s.Reader.Read(d)
}

// Close returns the buffer to the pool, discards any unread buffered bytes,
// and closes Reader if it implements io.Closer. After Close, Read goes
// straight to Reader.
func (s *BufferReader) Close() error {
	if s.buf != nil {
		bufferPoolCopy.Put(s.buf)
		s.buf = nil
	}
	// The storage now belongs to the pool; do not let Read index into it.
	s.roff, s.rend = 0, 0
	if c, ok := s.Reader.(io.Closer); ok {
		return c.Close()
	}
	return nil
}