void AsyncSSLSocket::clientHelloParsingCallback()

in folly/io/async/AsyncSSLSocket.cpp [2027:2168]


/**
 * Message callback (OpenSSL msg-callback signature) that captures details of
 * the peer's ClientHello into `sock->clientHelloInfo_`.
 *
 * Received handshake bytes are accumulated in
 * `clientHelloInfo_->clientHelloBuf_` until a complete ClientHello is
 * available. The message is then parsed for:
 *   - the legacy version,
 *   - the cipher suites,
 *   - the compression methods,
 *   - the extension list.
 * Within the extension list, signature algorithms, supported versions, the
 * SNI host name and ALPN are decoded. `resetClientHelloParsing(ssl)` is
 * called when:
 *   - an outgoing message is observed,
 *   - the first handshake message is not a ClientHello,
 *   - the ClientHello has been processed.
 *
 * @param written     non-zero for an outgoing message.
 * @param contentType record content type; only handshake records are used.
 * @param buf, len    message bytes reported by OpenSSL.
 * @param ssl         connection handle, forwarded to resetClientHelloParsing.
 * @param arg         the AsyncSSLSocket that installed this callback.
 *
 * Every length field comes from the untrusted peer. Lengths are checked
 * before being subtracted, so a malformed ClientHello cannot wrap the
 * unsigned counters and desynchronize the parser. Truncated data surfaces as
 * std::out_of_range from the Cursor; whatever was parsed so far is kept.
 */
void AsyncSSLSocket::clientHelloParsingCallback(
    int written,
    int /* version */,
    int contentType,
    const void* buf,
    size_t len,
    SSL* ssl,
    void* arg) {
  auto sock = static_cast<AsyncSSLSocket*>(arg);
  if (written != 0) {
    sock->resetClientHelloParsing(ssl);
    return;
  }
  if (contentType != SSL3_RT_HANDSHAKE) {
    return;
  }
  if (len == 0) {
    return;
  }

  auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
  // Wrap without copying first; the bytes are only copied below if we have
  // to wait for more data.
  clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
  try {
    Cursor cursor(clientHelloBuf.front());
    if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
      sock->resetClientHelloParsing(ssl);
      return;
    }

    if (cursor.totalLength() < 3) {
      // The 24-bit length field is incomplete. The wrapper does not own
      // `buf`, so swap it for an owned copy before waiting for more data.
      clientHelloBuf.trimEnd(len);
      clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
      return;
    }

    // 24-bit big-endian handshake message length.
    uint32_t messageLength = cursor.read<uint8_t>();
    messageLength <<= 8;
    messageLength |= cursor.read<uint8_t>();
    messageLength <<= 8;
    messageLength |= cursor.read<uint8_t>();
    if (cursor.totalLength() < messageLength) {
      // The body is not fully received; keep an owned copy, as above.
      clientHelloBuf.trimEnd(len);
      clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
      return;
    }

    sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
    sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();

    cursor.skip(4); // gmt_unix_time
    cursor.skip(28); // random_bytes

    cursor.skip(cursor.read<uint8_t>()); // session_id

    // Cipher suites are 2 bytes each. A trailing odd byte (malformed input)
    // is skipped so the cursor stays aligned with the compression methods.
    auto cipherSuitesLength = cursor.readBE<uint16_t>();
    for (int i = 0; i < cipherSuitesLength / 2; ++i) {
      sock->clientHelloInfo_->clientHelloCipherSuites_.push_back(
          cursor.readBE<uint16_t>());
    }
    cursor.skip(cipherSuitesLength % 2);

    auto compressionMethodsLength = cursor.read<uint8_t>();
    for (int i = 0; i < compressionMethodsLength; ++i) {
      sock->clientHelloInfo_->clientHelloCompressionMethods_.push_back(
          cursor.readBE<uint8_t>());
    }

    if (cursor.totalLength() > 0) {
      auto extensionsLength = cursor.readBE<uint16_t>();
      // Each extension is: type (2 bytes) | data length (2 bytes) | data.
      while (extensionsLength >= 4) {
        auto extensionType =
            static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
        sock->clientHelloInfo_->clientHelloExtensions_.push_back(extensionType);
        auto extensionDataLength = cursor.readBE<uint16_t>();
        extensionsLength -= 4;
        if (extensionDataLength > extensionsLength) {
          // The extension claims more bytes than the extensions block has
          // left. The input is malformed, so stop parsing extensions.
          break;
        }
        extensionsLength -= extensionDataLength;

        if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
          // 2-byte list length, then (hash, signature) byte pairs.
          if (extensionDataLength >= 2) {
            cursor.skip(2);
            extensionDataLength -= 2;
            while (extensionDataLength >= 2) {
              auto hashAlg =
                  static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
              auto sigAlg = static_cast<ssl::SignatureAlgorithm>(
                  cursor.readBE<uint8_t>());
              extensionDataLength -= 2;
              sock->clientHelloInfo_->clientHelloSigAlgs_.emplace_back(
                  hashAlg, sigAlg);
            }
          }
          // Consume bytes a malformed body left over, keeping the cursor
          // aligned with the next extension header.
          cursor.skip(extensionDataLength);
        } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) {
          // 1-byte list length, then 2-byte versions.
          if (extensionDataLength >= 1) {
            cursor.skip(1);
            extensionDataLength -= 1;
            while (extensionDataLength >= 2) {
              sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back(
                  cursor.readBE<uint16_t>());
              extensionDataLength -= 2;
            }
          }
          cursor.skip(extensionDataLength);
        } else if (extensionType == ssl::TLSExtension::SERVER_NAME) {
          static_assert(
              std::is_same<
                  typename std::underlying_type<ssl::NameType>::type,
                  uint8_t>::value,
              "unexpected underlying type");
          // 2-byte list length, then entries of:
          // name type (1 byte) | name length (2 bytes) | name.
          if (extensionDataLength >= 2) {
            cursor.skip(2);
            extensionDataLength -= 2;
            while (extensionDataLength >= 3) {
              auto typ = static_cast<ssl::NameType>(cursor.readBE<uint8_t>());
              auto nameLength = cursor.readBE<uint16_t>();
              extensionDataLength -= sizeof(typ) + sizeof(nameLength);
              if (nameLength > extensionDataLength) {
                // The entry overruns its extension. The skip after this
                // loop realigns the cursor.
                break;
              }

              if (typ == ssl::NameType::HOST_NAME &&
                  sock->clientHelloInfo_->clientHelloSNIHostname_.empty() &&
                  cursor.canAdvance(nameLength)) {
                sock->clientHelloInfo_->clientHelloSNIHostname_ =
                    cursor.readFixedString(nameLength);
              } else {
                // Skip |nameLength| bytes to keep the cursor in sync. If the
                // remaining buffer is shorter than nameLength, this throws.
                cursor.skip(nameLength);
              }
              extensionDataLength -= nameLength;
            }
          }
          cursor.skip(extensionDataLength);
        } else if (
            extensionType ==
            ssl::TLSExtension::APPLICATION_LAYER_PROTOCOL_NEGOTIATION) {
          parseClientAlpns(sock, cursor, extensionDataLength);
        } else {
          cursor.skip(extensionDataLength);
        }
      }
    }
  } catch (const std::out_of_range&) {
    // we'll use what we found and cleanup below.
    VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
            << "buffer finished unexpectedly."
            << " AsyncSSLSocket socket=" << sock;
  }

  sock->resetClientHelloParsing(ssl);
}