packetbeat/protos/tls/parse.go (575 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package tls
import (
"crypto/dsa" //nolint:staticcheck // SA1019 Deprecated, but still used. So we have to handle it.
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"fmt"
"strings"
"github.com/elastic/beats/v7/libbeat/common/streambuf"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/mapstr"
)
type direction uint8
const (
dirUnknown direction = iota
dirClient
dirServer
)
const (
maxTLSRecordLength = (1 << 14) + 2048
// For safety, ignore handshake messages longer than 64k (same as stdlib)
maxHandshakeSize = 1 << 16
recordHeaderSize = 5
handshakeHeaderSize = 4
helloHeaderLength = 7
randomDataLength = 28
)
type recordType uint8
const (
recordTypeChangeCipherSpec recordType = 20
recordTypeAlert recordType = 21
recordTypeHandshake recordType = 22
recordTypeApplicationData recordType = 23
)
type handshakeType uint8
const (
helloRequest handshakeType = 0
clientHello handshakeType = 1
serverHello handshakeType = 2
certificate handshakeType = 11
serverKeyExchange handshakeType = 12
certificateRequest handshakeType = 13
clientKeyExchange handshakeType = 16
certificateStatus handshakeType = 22
)
type parserResult int8
const (
resultOK parserResult = iota
resultFailed
resultMore
resultEncrypted
)
type tlsTicket struct {
present bool
value string
}
type parser struct {
// Buffer to accumulate records until a full handshake message
// is received
handshakeBuf streambuf.Buffer
direction direction
alerts []alert
certificates []*x509.Certificate
hello *helloMessage
// If this end of the connection (server) asked the other end (client)
// for a certificate
certRequested bool
// ocspResponse is the top level OCSP response status.
ocspResponse ocspResponseStatus
ocspResponseIsValid bool
// If a key-exchange message has been sent. Used to detect session resumption
keyExchanged bool
}
// https://www.rfc-editor.org/rfc/rfc6960#section-4.2.1
type ocspResponseStatus byte
func (s ocspResponseStatus) String() string {
switch s {
case 0: // Response has valid confirmations
return "successful"
case 1: // Illegal confirmation request
return "malformedRequest"
case 2: // Internal error in issuer
return "internalError"
case 3: // Try again later
return "tryLater"
case 5: // Must sign the request
return "sigRequired"
case 6: // Request unauthorized
return "unauthorized"
default:
return fmt.Sprint(byte(s))
}
}
type tlsVersion struct {
major, minor uint8
}
type recordHeader struct {
recordType recordType
version tlsVersion
length uint16
}
type handshakeHeader struct {
handshakeType handshakeType
length int
}
type helloMessage struct {
version tlsVersion
random []byte
timestamp uint32
sessionID string
ticket tlsTicket
supported struct {
cipherSuites []cipherSuite
compression []compressionMethod
}
selected struct {
cipherSuite cipherSuite
compression compressionMethod
}
extensions Extensions
}
func readRecordHeader(buf *streambuf.Buffer) (*recordHeader, error) {
var (
header recordHeader
err error
record uint8
)
if record, err = buf.ReadNetUint8At(0); err != nil {
return nil, err
}
header.recordType = recordType(record)
if header.version.major, err = buf.ReadNetUint8At(1); err != nil {
return nil, err
}
if header.version.minor, err = buf.ReadNetUint8At(2); err != nil {
return nil, err
}
if header.length, err = buf.ReadNetUint16At(3); err != nil {
return nil, err
}
return &header, nil
}
func readHandshakeHeader(buf *streambuf.Buffer) (*handshakeHeader, error) {
var err error
var len8, typ uint8
var len16 uint16
if typ, err = buf.ReadNetUint8At(0); err != nil {
return nil, err
}
if len8, err = buf.ReadNetUint8At(1); err != nil {
return nil, err
}
if len16, err = buf.ReadNetUint16At(2); err != nil {
return nil, err
}
return &handshakeHeader{
handshakeType(typ),
int(len16) | (int(len8) << 16),
}, nil
}
func (header *recordHeader) String() string {
return fmt.Sprintf("recordHeader type[%v] version[%v] length[%d]",
header.recordType, header.version, header.length)
}
func (header *recordHeader) isValid() bool {
return header.version.major == 3 && header.length <= maxTLSRecordLength
}
func (hello *helloMessage) toMap() mapstr.M {
m := mapstr.M{
"version": fmt.Sprintf("%d.%d", hello.version.major, hello.version.minor),
}
if len(hello.sessionID) != 0 {
m["session_id"] = hello.sessionID
}
if len(hello.random) != 0 {
m["random"] = hex.EncodeToString(hello.random)
}
if len(hello.supported.compression) > 0 {
comp := make([]string, len(hello.supported.compression))
for idx, code := range hello.supported.compression {
comp[idx] = code.String()
}
m["supported_compression_methods"] = comp
} else {
m["selected_compression_method"] = hello.selected.compression.String()
}
if hello.extensions.Parsed != nil {
m["extensions"] = hello.extensions.Parsed
}
return m
}
func (hello *helloMessage) supportedCiphers() []string {
ciphers := make([]string, len(hello.supported.cipherSuites))
for idx, code := range hello.supported.cipherSuites {
ciphers[idx] = code.String()
}
return ciphers
}
func (parser *parser) parse(buf *streambuf.Buffer) parserResult {
for buf.Avail(recordHeaderSize) {
header, err := readRecordHeader(buf)
if err != nil || !header.isValid() {
if err != nil {
logp.Warn("internal buffer error: %v", err)
}
return resultFailed
}
limit := recordHeaderSize + int(header.length)
if !buf.Avail(limit) {
// wait for complete record
return resultMore
}
switch header.recordType {
case recordTypeChangeCipherSpec: // single message of size 1 (byte 1)
if isDebug {
debugf("handshake completed")
}
// discard remaining data for this stream (encrypted)
_ = buf.Advance(buf.Len())
return resultEncrypted
case recordTypeHandshake:
if isDebug {
debugf("got handshake record of size %d", header.length)
}
if err = parser.bufferHandshake(buf, int(header.length)); err != nil {
logp.Warn("Error parsing handshake message: %v", err)
return resultFailed
}
case recordTypeAlert:
if err = parser.parseAlert(newBufferView(buf, recordHeaderSize, int(header.length))); err != nil {
logp.Warn("Error parsing alert message: %v", err)
return resultFailed
}
case recordTypeApplicationData:
// TODO: Request / Response analytics
if isDebug {
debugf("ignoring application data length %d", header.length)
}
default:
if isDebug {
debugf("ignoring record type %d length %d", header.recordType, header.length)
}
}
_ = buf.Advance(limit)
}
if buf.Len() == 0 {
return resultOK
}
return resultMore
}
func (parser *parser) bufferHandshake(buf *streambuf.Buffer, length int) (err error) {
// TODO: parse in-place if message in received buffer is complete
err = parser.handshakeBuf.Append(buf.Bytes()[recordHeaderSize : recordHeaderSize+length])
if err != nil {
logp.Warn("failed appending to buffer: %v", err)
// Discard buffer
parser.handshakeBuf.Init(nil, false)
return err
}
// Recover from any bufferView.subview out of bounds errors.
defer func() {
r := recover()
switch r := r.(type) {
case nil:
case bufferViewError:
err = r
default:
panic(r)
}
}()
for parser.handshakeBuf.Avail(handshakeHeaderSize) {
// type
header, err := readHandshakeHeader(&parser.handshakeBuf)
if err != nil {
logp.Warn("read failed: %v", err)
parser.handshakeBuf.Init(nil, false)
return err
}
if header.length > maxHandshakeSize {
// Discard buffer
parser.handshakeBuf.Init(nil, false)
return fmt.Errorf("message too large (%d bytes)", header.length)
}
limit := handshakeHeaderSize + header.length
if limit > parser.handshakeBuf.Len() {
break
}
if !parser.parseHandshake(header.handshakeType,
bufferView{&parser.handshakeBuf, handshakeHeaderSize, limit}) {
_ = parser.handshakeBuf.Advance(limit)
return fmt.Errorf("bad handshake %+v", header)
}
_ = parser.handshakeBuf.Advance(limit)
}
if parser.handshakeBuf.Len() == 0 {
parser.handshakeBuf.Reset()
}
return nil
}
func (parser *parser) setDirection(dir direction) {
if parser.direction != dir && parser.direction != dirUnknown {
logp.Warn("client/server identification mismatch")
}
parser.direction = dir
}
func (parser *parser) parseHandshake(handshakeType handshakeType, buffer bufferView) bool {
if isDebug {
debugf("got handshake message %v [%d]", handshakeType, buffer.length())
}
switch handshakeType {
case helloRequest:
parser.setDirection(dirServer)
return parseHelloRequest(buffer)
case clientHello:
parser.setDirection(dirClient)
if parser.hello = parseClientHello(buffer); parser.hello == nil {
return false
}
return true
case serverHello:
parser.setDirection(dirServer)
if parser.hello = parseServerHello(buffer); parser.hello == nil {
return false
}
return true
case certificate:
certs := parseCertificates(buffer)
parser.certificates = append(parser.certificates, certs...)
case certificateRequest:
parser.setDirection(dirServer)
parser.certRequested = true
case clientKeyExchange:
parser.setDirection(dirClient)
parser.keyExchanged = true
case serverKeyExchange:
parser.setDirection(dirServer)
parser.keyExchanged = true
case certificateStatus:
parser.ocspResponse, parser.ocspResponseIsValid = parseOCSPStatus(buffer)
}
return true
}
func parseHelloRequest(buffer bufferView) bool {
if buffer.length() != 0 {
logp.Warn("non-empty hello request")
}
return true
}
func parseCommonHello(buffer bufferView, dest *helloMessage) (int, bool) {
var sessionIDLength uint8
if !buffer.read8(0, &dest.version.major) ||
!buffer.read8(1, &dest.version.minor) ||
!buffer.read32Net(2, &dest.timestamp) ||
// ignore 28 random bytes
!buffer.read8(6+randomDataLength, &sessionIDLength) {
logp.Warn("failed reading hello message")
return 0, false
}
if dest.version.major != 3 {
logp.Warn("Not a TLS hello (reported version %d.%d)",
dest.version.major, dest.version.minor)
return 0, false
}
if sessionIDLength > 32 {
logp.Warn("Not a TLS hello (session id length %d out of bounds)", sessionIDLength)
return 0, false
}
bytes := buffer.readBytes(7+randomDataLength, int(sessionIDLength))
if len(bytes) != int(sessionIDLength) {
logp.Warn("Not a TLS hello (failed reading session ID)")
return 0, false
}
dest.sessionID = hex.EncodeToString(bytes)
dest.random = buffer.readBytes(2, 4+randomDataLength)
return helloHeaderLength + randomDataLength + int(sessionIDLength), true
}
func (hello *helloMessage) parseExtensions(buffer bufferView) {
hello.extensions = ParseExtensions(buffer)
if ticket, err := hello.extensions.Parsed.GetValue("session_ticket"); err == nil {
if value, ok := ticket.(string); ok {
hello.ticket.present = true
hello.ticket.value = value
} else {
logp.Err("tls ticket data type error")
}
}
}
func parseClientHello(buffer bufferView) *helloMessage {
var result helloMessage
pos, ok := parseCommonHello(buffer, &result)
if !ok {
return nil
}
var cipherSuitesLength uint16
if !buffer.read16Net(pos, &cipherSuitesLength) {
logp.Warn("failed parsing client hello cipher suite length")
return nil
}
for base := pos + 2; base < pos+2+int(cipherSuitesLength); base += 2 {
var cipher uint16
if !buffer.read16Net(base, &cipher) {
logp.Warn("failed parsing client hello cipher suite")
return nil
}
if !isGreaseValue(cipher) {
result.supported.cipherSuites = append(result.supported.cipherSuites, cipherSuite(cipher))
}
}
pos += 2 + int(cipherSuitesLength)
var compMethodsLength uint8
if !buffer.read8(pos, &compMethodsLength) {
logp.Warn("failed parsing client hello compression methods length")
return nil
}
limit := pos + 1 + int(compMethodsLength)
for base := pos + 1; base < limit; base++ {
var method uint8
if !buffer.read8(base, &method) {
logp.Warn("failed parsing client hello compression methods")
return nil
}
result.supported.compression = append(result.supported.compression, compressionMethod(method))
}
result.parseExtensions(buffer.subview(limit, buffer.limit-limit))
return &result
}
func parseServerHello(buffer bufferView) *helloMessage {
var result helloMessage
pos, ok := parseCommonHello(buffer, &result)
if !ok {
return nil
}
var cipher uint16
var compression uint8
if !buffer.read16Net(pos, &cipher) ||
!buffer.read8(pos+2, &compression) {
return nil
}
result.selected.cipherSuite = cipherSuite(cipher)
result.selected.compression = compressionMethod(compression)
result.parseExtensions(buffer.subview(pos+3, buffer.limit-pos-3))
return &result
}
func parseCertificates(buffer bufferView) (certs []*x509.Certificate) {
var totalLen uint32
if !buffer.read24Net(0, &totalLen) || int(totalLen+3) != buffer.length() {
return nil
}
for pos, limit := 3, int(totalLen)+3; pos+3 <= limit; {
var certLen uint32
if !buffer.read24Net(pos, &certLen) || pos+3+int(certLen) > limit {
return nil
}
raw := buffer.readBytes(pos+3, int(certLen))
if len(raw) != int(certLen) {
return nil
}
parsed, err := x509.ParseCertificate(raw)
if err != nil {
return nil
}
certs = append(certs, parsed)
pos += 3 + int(certLen)
}
return certs
}
func parseOCSPStatus(buffer bufferView) (status ocspResponseStatus, ok bool) {
const (
statusTypeLen = 1
respLengthLen = 3
ocspRespHeaderLen = 6
ocspStatusType = 1
)
var b byte
ok = buffer.read8(0, &b)
if !ok || b != ocspStatusType {
return 0, false
}
ok = buffer.read8(statusTypeLen+respLengthLen+ocspRespHeaderLen, &b)
return ocspResponseStatus(b), ok
}
func (version tlsVersion) String() string {
if version.major == 3 {
if version.minor > 0 {
return fmt.Sprintf("TLS 1.%d", version.minor-1)
}
return "SSL 3.0"
}
return fmt.Sprintf("(raw %d.%d)", version.major, version.minor)
}
// ProtocolVersion represents a version of the TLS protocol.
type ProtocolVersion struct {
// Protocol in use. One of "tls", "ssl" or "unknown".
Protocol string
// Version is the protocol version, as in "1.3" for tls or "3.0" for ssl.
Version string
}
// GetProtocolVersion returns the protocol and protocol version number
// associated to the raw TLS protocol version.
func (version tlsVersion) GetProtocolVersion() ProtocolVersion {
if version.major == 3 {
if version.minor == 0 {
return ProtocolVersion{Protocol: "ssl", Version: "3.0"}
}
return ProtocolVersion{Protocol: "tls", Version: fmt.Sprintf("1.%d", version.minor-1)}
}
return ProtocolVersion{Protocol: "unknown", Version: fmt.Sprintf("%d.%d", version.major, version.minor)}
}
// IsZero returns if this version is the zero value (unset).
func (version tlsVersion) IsZero() bool {
return version.major == 0 && version.minor == 0
}
func getKeySize(key interface{}) int {
if key == nil {
return 0
}
switch pubKey := key.(type) {
case *rsa.PublicKey:
if n := pubKey.N; n != nil {
return n.BitLen()
}
case *dsa.PublicKey:
if p := pubKey.Parameters.P; p != nil {
return p.BitLen()
}
if y := pubKey.Y; y != nil {
return y.BitLen()
}
case *ecdsa.PublicKey:
if params := pubKey.Params(); params != nil {
return params.BitSize
}
if y := pubKey.Y; y != nil {
return y.BitLen()
}
}
return 0
}
// certToMap takes an x509 cert and converts it into a map.
func certToMap(cert *x509.Certificate) mapstr.M {
certMap := mapstr.M{
"signature_algorithm": cert.SignatureAlgorithm.String(),
"public_key_algorithm": toString(cert.PublicKeyAlgorithm),
"serial_number": strings.ToUpper(cert.SerialNumber.Text(16)),
"issuer": toMap(&cert.Issuer),
"subject": toMap(&cert.Subject),
"not_before": cert.NotBefore,
"not_after": cert.NotAfter,
"version_number": cert.Version,
}
if keySize := getKeySize(cert.PublicKey); keySize > 0 {
certMap["public_key_size"] = keySize
}
san := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses)+len(cert.EmailAddresses))
san = append(append(san, cert.DNSNames...), cert.EmailAddresses...)
for _, ip := range cert.IPAddresses {
san = append(san, ip.String())
}
if len(san) > 0 {
certMap["alternative_names"] = san
}
return certMap
}
func toMap(name *pkix.Name) mapstr.M {
result := mapstr.M{}
fields := []struct {
name string
value interface{}
}{
{"country", name.Country},
{"organization", name.Organization},
{"organizational_unit", name.OrganizationalUnit},
{"locality", name.Locality},
{"postal_code", name.PostalCode},
{"serial_number", name.SerialNumber},
{"common_name", name.CommonName},
{"street_address", name.StreetAddress},
{"state_or_province", name.Province},
{"distinguished_name", name.String()},
}
for _, field := range fields {
var str string
switch value := field.value.(type) {
case string:
str = value
case []string:
str = strings.Join(value, " ")
}
if len(str) > 0 {
result[field.name] = str
}
}
return result
}
func (parser *parser) hasInfo() bool {
return parser.hello != nil || len(parser.alerts) != 0 || len(parser.certificates) != 0
}