pkg/hbone/sni.go (170 lines of code) (raw):

// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package hbone import ( "context" "crypto/tls" "errors" "io" "log" "net" "strings" "time" ) // Will start a SNI proxy, similar with Istio East-West or Gateway SNI router. // Accepted connections will decode the ServerName header, and use it to forward to either a HBONE // mTLS service or a H2R connection. func (hc *Endpoint) sniProxy(ctx context.Context, stdin io.Reader, stdout io.WriteCloser) error { d := net.Dialer{} // TODO: customizations conn, err := d.DialContext(ctx, "tcp", hc.SNIGate) if Debug { log.Println("sniProxyC: ", conn.RemoteAddr(), hc.URL, hc.SNIGate) } if err != nil { return err } // Using the low-level interface, to keep control over TLS. conf := &tls.Config{} conf.ServerName = hc.SNI defer conn.Close() tlsCon := tls.Client(conn, conf) err = HandshakeTimeout(tlsCon, hc.hb.HandsahakeTimeout, nil) if err != nil { return err } return proxy(ctx, stdin, stdout, tlsCon, tlsCon) } func (hb *HBone) HandleSNIConn(conn net.Conn) { s := NewBufferReader(conn) // will also close the conn ( which is the reader ) defer s.Close() sni, err := ParseTLS(s) if err != nil { log.Println("SNI invalid TLS", sni, err) return } // Based on SNI, make a hbone request, using JWT auth. if hb.EndpointResolver != nil { dst := hb.EndpointResolver(sni) if dst != nil { if Debug { log.Println("SNI: start proxy", "sni", sni, "URL", dst.URL) } t0 := time.Now() err = dst.Proxy(context.Background(), s, conn) if err != nil { log.Println("SNI: error connecting to proxy", "sni", sni, "error", err, "URL", dst.URL) } else { log.Println("SNI:done", "sni", sni, "URL", dst.URL, "dur", time.Since(t0)) } } else { log.Println("SNI: Missing destination", "sni", sni) } } else { log.Println("SNI: Missing EndpointResolver", "sni", sni) } } var sniErr = errors.New("Invalid TLS") type ClientHelloMsg struct { // 22 vers uint16 //random []byte sessionId []byte //CipherSuites []uint16 //compressionMethods []uint8 ServerName string //ocspStapling bool //scts bool //supportedPoints []uint8 //ticketSupported bool //sessionTicket []uint8 //secureRenegotiation []byte } // TLS extension numbers const ( extensionServerName uint16 = 0 ) // TODO: if a session ID is provided, use it as a cookie and attempt // to find the corresponding host. // On server side generate session IDs ! // // TODO: in mesh, use one cypher suite (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) // maybe 2 ( since keys are ECDSA ) func ParseTLS(acc *BufferReader) (string, error) { buf, err := acc.Fill(5) if err != nil { return "", err } typ := buf[0] // 22 3 1 2 0 if typ != 0x16 { return "", sniErr } vers := uint16(buf[1])<<8 | uint16(buf[2]) if vers != 0x301 { log.Println("Version ", vers) } rlen := int(buf[3])<<8 | int(buf[4]) if rlen > 16*1024 { log.Println("RLen ", rlen) return "", sniErr } off := 5 m := ClientHelloMsg{} end := rlen + 5 buf, err = acc.Fill(end) if err != nil { return "", err } clientHello := buf[5:end] chLen := end - 5 if chLen < 38 { log.Println("chLen ", chLen) return "", sniErr } // off is the last byte in the buffer - will be forwarded m.vers = uint16(clientHello[4])<<8 | uint16(clientHello[5]) // random: data[6:38] sessionIdLen := int(clientHello[38]) if sessionIdLen > 32 || chLen < 39+sessionIdLen { log.Println("sLen ", sessionIdLen) return "", sniErr } m.sessionId = clientHello[39 : 39+sessionIdLen] off = 39 + sessionIdLen // cipherSuiteLen is the number of bytes of cipher suite numbers. Since // they are uint16s, the number must be even. cipherSuiteLen := int(clientHello[off])<<8 | int(clientHello[off+1]) off += 2 if cipherSuiteLen%2 == 1 || chLen-off < 2+cipherSuiteLen { return "", sniErr } //numCipherSuites := cipherSuiteLen / 2 //m.cipherSuites = make([]uint16, numCipherSuites) //for i := 0; i < numCipherSuites; i++ { // m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) //} off += cipherSuiteLen compressionMethodsLen := int(clientHello[off]) off++ if chLen-off < 1+compressionMethodsLen { return "", sniErr } //m.compressionMethods = data[1 : 1+compressionMethodsLen] off += compressionMethodsLen if off+2 > chLen { // ClientHello is optionally followed by extension data return "", sniErr } extensionsLength := int(clientHello[off])<<8 | int(clientHello[off+1]) off = off + 2 if extensionsLength != chLen-off { return "", sniErr } for off < chLen { extension := uint16(clientHello[off])<<8 | uint16(clientHello[off+1]) off += 2 length := int(clientHello[off])<<8 | int(clientHello[off+1]) off += 2 if off >= end { return "", sniErr } switch extension { case extensionServerName: d := clientHello[off : off+length] if len(d) < 2 { return "", sniErr } namesLen := int(d[0])<<8 | int(d[1]) d = d[2:] if len(d) != namesLen { return "", sniErr } for len(d) > 0 { if len(d) < 3 { return "", sniErr } nameType := d[0] nameLen := int(d[1])<<8 | int(d[2]) d = d[3:] if len(d) < nameLen { return "", sniErr } if nameType == 0 { m.ServerName = string(d[:nameLen]) // An SNI value may not include a // trailing dot. See // https://tools.ietf.org/html/rfc6066#section-3. if strings.HasSuffix(m.ServerName, ".") { return "", sniErr } break } d = d[nameLen:] } default: //log.Println("TLS Ext", extension, length) } off += length } // Does not contain port !!! Assume the port is 443, or map it. // TODO: unmangle server name - port, mesh node return m.ServerName, nil }