quic/server/QuicServerWorker.cpp (1,184 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <fmt/format.h>
#include <folly/chrono/Conv.h>
#include <folly/io/Cursor.h>
#include <folly/io/SocketOptionMap.h>
#include <folly/system/ThreadId.h>
#include <quic/QuicConstants.h>
#include <quic/common/SocketUtil.h>
#include <quic/common/Timers.h>
#include <atomic>
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
#include <linux/net_tstamp.h>
#else
#define SOF_TIMESTAMPING_SOFTWARE 0
#endif
#include <folly/Conv.h>
#include <quic/congestion_control/Bbr.h>
#include <quic/congestion_control/Copa.h>
#include <quic/fizz/handshake/FizzRetryIntegrityTagGenerator.h>
#include <quic/server/AcceptObserver.h>
#include <quic/server/CCPReader.h>
#include <quic/server/QuicServerWorker.h>
#include <quic/server/handshake/StatelessResetGenerator.h>
#include <quic/server/handshake/TokenGenerator.h>
#include <quic/state/QuicConnectionStats.h>
namespace quic {
// Process-wide count of server transports that have been created but have
// not yet finished their handshake. dispatchPacketData() increments it for
// every newly made transport and compares it against
// unfinishedHandshakeLimitFn_ to decide whether to demand a Retry. The
// decrement is not visible in this part of the file (presumably in the
// handshake-finished/unfinished callbacks — confirm).
std::atomic_int globalUnfinishedHandshakes{0};
// Creates a worker that reports to `callback`.
//
// @param callback          Worker-level callback; taken by value and moved
//                          into place (sink parameter) to avoid an extra
//                          atomic refcount round-trip.
// @param setEventCallback  When true, bind() registers this worker as the
//                          socket's event callback.
QuicServerWorker::QuicServerWorker(
    std::shared_ptr<QuicServerWorker::WorkerCallback> callback,
    bool setEventCallback)
    : callback_(std::move(callback)),
      setEventCallback_(setEventCallback),
      takeoverPktHandler_(this),
      observerList_(this) {
  ccpReader_ = std::make_unique<CCPReader>();
  // The prune hook is owned by a member of this worker, so capturing `this`
  // cannot outlive the object. Each prune of buffered 0-RTT data is counted.
  pending0RttData_.setPruneHook(
      [this](auto, auto) { QUIC_STATS(statsCallback_, onZeroRttBufferedPruned); });
}
// Returns the event base captured from the socket in setSocket() (may be null
// before a socket has been set).
folly::EventBase* QuicServerWorker::getEventBase() const {
  folly::EventBase* const eventBase = this->evb_;
  return eventBase;
}
// Takes ownership of the worker's UDP socket and adopts the socket's event
// base as the worker's event base.
void QuicServerWorker::setSocket(
    std::unique_ptr<folly::AsyncUDPSocket> socket) {
  this->socket_ = std::move(socket);
  this->evb_ = this->socket_->getEventBase();
}
// Binds the worker's socket to `address` and configures it.
//
// This applies pre/post-bind socket options, sets DF and disables PMTU
// discovery, enables GRO when more GRO buffers than the default are
// requested, and turns on software receive timestamps. Requires
// supportedVersions_ to be populated and a socket to have been set.
void QuicServerWorker::bind(
    const folly::SocketAddress& address,
    folly::AsyncUDPSocket::BindOptions bindOptions) {
  DCHECK(!supportedVersions_.empty());
  CHECK(socket_);
  if (setEventCallback_) {
    socket_->setEventCallback(this);
  }
  const auto family = address.getFamily();
  // TODO this totally doesn't work, we can't apply socket options before
  // bind, since bind creates the fd.
  if (socketOptions_) {
    applySocketOptions(
        *socket_,
        *socketOptions_,
        family,
        folly::SocketOptionKey::ApplyPos::PRE_BIND);
  }
  socket_->bind(address, bindOptions);
  if (socketOptions_) {
    applySocketOptions(
        *socket_,
        *socketOptions_,
        family,
        folly::SocketOptionKey::ApplyPos::POST_BIND);
  }
  socket_->setDFAndTurnOffPMTU();

  const bool wantsGro =
      transportSettings_.numGROBuffers_ > kDefaultNumGROBuffers;
  if (wantsGro) {
    socket_->setGRO(true);
    // Only raise the buffer count if the kernel actually accepted GRO;
    // never exceed kMaxNumGROBuffers.
    if (socket_->getGRO() > 0) {
      if (transportSettings_.numGROBuffers_ < kMaxNumGROBuffers) {
        numGROBuffers_ = transportSettings_.numGROBuffers_;
      } else {
        numGROBuffers_ = kMaxNumGROBuffers;
      }
    }
  }
  socket_->setTimestamping(SOF_TIMESTAMPING_SOFTWARE);
}
void QuicServerWorker::applyAllSocketOptions() {
CHECK(socket_);
if (socketOptions_) {
applySocketOptions(
*socket_.get(),
*socketOptions_,
getAddress().getFamily(),
folly::SocketOptionKey::ApplyPos::PRE_BIND);
applySocketOptions(
*socket_.get(),
*socketOptions_,
getAddress().getFamily(),
folly::SocketOptionKey::ApplyPos::POST_BIND);
}
}
// Installs a hook that may return per-client transport settings; it is
// consulted whenever a new transport is created.
void QuicServerWorker::setTransportSettingsOverrideFn(
    TransportSettingsOverrideFn fn) {
  this->transportSettingsOverrideFn_ = std::move(fn);
}
void QuicServerWorker::setTransportStatsCallback(
std::unique_ptr<QuicTransportStatsCallback> statsCallback) noexcept {
CHECK(statsCallback);
statsCallback_ = std::move(statsCallback);
}
// Non-owning view of the stats sink (null if none was installed).
QuicTransportStatsCallback* QuicServerWorker::getTransportStatsCallback()
    const noexcept {
  auto* const stats = this->statsCallback_.get();
  return stats;
}
void QuicServerWorker::setConnectionIdAlgo(
std::unique_ptr<ConnectionIdAlgo> connIdAlgo) noexcept {
CHECK(connIdAlgo);
connIdAlgo_ = std::move(connIdAlgo);
}
// Sets the congestion controller factory handed to every new transport.
//
// @param ccFactory  Must be non-null. Taken by value and moved into place so
//                   the worker becomes a co-owner without an extra refcount
//                   increment/decrement.
void QuicServerWorker::setCongestionControllerFactory(
    std::shared_ptr<CongestionControllerFactory> ccFactory) {
  CHECK(ccFactory);
  ccFactory_ = std::move(ccFactory);
}
void QuicServerWorker::setRateLimiter(
std::unique_ptr<RateLimiter> rateLimiter) {
newConnRateLimiter_ = std::move(rateLimiter);
}
void QuicServerWorker::setUnfinishedHandshakeLimit(
std::function<int()> limitFn) {
unfinishedHandshakeLimitFn_ = std::move(limitFn);
}
// Starts reading from the socket. Lazily creates a pacing timer on the
// worker's event base if one was not supplied via setPacingTimer().
void QuicServerWorker::start() {
  CHECK(socket_);
  const bool needsPacingTimer = !pacingTimer_;
  if (needsPacingTimer) {
    pacingTimer_ = TimerHighRes::newTimer(
        evb_, transportSettings_.pacingTimerTickInterval);
  }
  socket_->resumeRead(this);
  const auto pid = static_cast<int>(processId_);
  VLOG(10) << fmt::format(
      "Registered read on worker={}, thread={}, processId={}",
      fmt::ptr(this),
      folly::getCurrentThreadID(),
      pid);
}
void QuicServerWorker::pauseRead() {
CHECK(socket_);
socket_->pauseRead();
}
// Returns the raw file descriptor backing the worker's socket.
int QuicServerWorker::getFD() {
  CHECK(socket_);
  auto networkSocket = socket_->getNetworkSocket();
  return networkSocket.toFd();
}
// Returns the local address the worker's socket is bound to.
const folly::SocketAddress& QuicServerWorker::getAddress() const {
  CHECK(socket_ != nullptr);
  return this->socket_->address();
}
// Allocates a fresh receive buffer large enough for numGROBuffers_ packets of
// maxRecvPacketSize bytes each, and reports its address and capacity.
void QuicServerWorker::getReadBuffer(void** buf, size_t* len) noexcept {
  const auto capacity =
      transportSettings_.maxRecvPacketSize * numGROBuffers_;
  readBuffer_ = folly::IOBuf::create(capacity);
  *buf = readBuffer_->writableData();
  *len = capacity;
}
// Decides whether a long-header packet must be dropped or answered with a
// version negotiation packet before it is routed.
//
// Returns true when the packet was consumed here: it was dropped, or a
// version negotiation packet was written back to `client`. Returns false when
// the packet should continue through normal routing.
bool QuicServerWorker::maybeSendVersionNegotiationPacketOrDrop(
    const folly::SocketAddress& client,
    bool isInitial,
    LongHeaderInvariant& invariant,
    size_t datagramLen) {
  // Undersized Initials are dropped outright rather than answered.
  if (isInitial && datagramLen < kMinInitialPacketSize) {
    VLOG(3) << "Dropping initial packet due to invalid size";
    QUIC_STATS(
        statsCallback_, onPacketDropped, PacketDropReason::INVALID_PACKET);
    return true;
  }
  // A packet carrying the version-negotiation version is never treated as an
  // Initial below.
  const bool initialCandidate =
      isInitial && invariant.version != QuicVersion::VERSION_NEGOTIATION;

  folly::Optional<std::pair<VersionNegotiationPacket, Buf>> response;
  if (rejectNewConnections_() && initialCandidate) {
    // Refusing new connections: advertise only MVFST_INVALID.
    VersionNegotiationPacketBuilder rejectBuilder(
        invariant.dstConnId,
        invariant.srcConnId,
        std::vector<QuicVersion>{QuicVersion::MVFST_INVALID});
    response = folly::make_optional(std::move(rejectBuilder).buildPacket());
  } else {
    const bool versionSupported =
        std::find(
            supportedVersions_.begin(),
            supportedVersions_.end(),
            invariant.version) != supportedVersions_.end();
    if (!versionSupported) {
      // Only Initials get a negotiation response; anything else is dropped.
      if (!initialCandidate) {
        VLOG(3) << "Dropping non-initial packet due to invalid version";
        QUIC_STATS(
            statsCallback_, onPacketDropped, PacketDropReason::INVALID_PACKET);
        return true;
      }
      VersionNegotiationPacketBuilder supportedBuilder(
          invariant.dstConnId, invariant.srcConnId, supportedVersions_);
      response =
          folly::make_optional(std::move(supportedBuilder).buildPacket());
    }
  }

  if (!response) {
    return false;
  }
  VLOG(4) << "Version negotiation sent to client=" << client;
  const auto responseLen = response->second->computeChainDataLength();
  QUIC_STATS(statsCallback_, onWrite, responseLen);
  QUIC_STATS(statsCallback_, onPacketProcessed);
  QUIC_STATS(statsCallback_, onPacketSent);
  socket_->write(client, response->second);
  return true;
}
// Entry point for a datagram (or GRO batch) read from the worker's UDP socket.
//
// Derives a receive timestamp, takes ownership of readBuffer_, and passes each
// QUIC datagram to handleNetworkData().
//
// @param client     Peer address the data came from.
// @param len        Number of bytes read into readBuffer_.
// @param truncated  True when the read did not fit in the buffer.
// @param params     Per-read metadata: optional kernel timestamps (`ts`) and
//                   the GRO segment size (`gro`; <= 0 means no GRO).
void QuicServerWorker::onDataAvailable(
    const folly::SocketAddress& client,
    size_t len,
    bool truncated,
    OnDataAvailableParams params) noexcept {
  auto packetReceiveTime = Clock::now();
  auto originalPacketReceiveTime = packetReceiveTime;
  if (params.ts) {
    // This is the software system time from the datagram.
    auto packetNowDuration =
        folly::to<std::chrono::microseconds>(params.ts.value()[0]);
    auto wallNowDuration =
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch());
    auto durationSincePacketNow = wallNowDuration - packetNowDuration;
    // Back-date the receive time by how long ago (in wall-clock terms) the
    // datagram was stamped. Skip this when the stamp is zero or not in the
    // past.
    if (packetNowDuration != 0us && durationSincePacketNow > 0us) {
      packetReceiveTime -= durationSincePacketNow;
    }
  }
  // System time can move backwards, so we want to make sure that the receive
  // time we are using is monotonic relative to itself.
  if (packetReceiveTime < largestPacketReceiveTime_) {
    packetReceiveTime = originalPacketReceiveTime;
  }
  largestPacketReceiveTime_ =
      std::max(largestPacketReceiveTime_, packetReceiveTime);
  VLOG(10) << fmt::format(
      "Worker={}, Received data on thread={}, processId={}",
      fmt::ptr(this),
      folly::getCurrentThreadID(),
      (int)processId_);
  // Move readBuffer_ first so that we can get rid
  // of it immediately so that if we return early,
  // we've flushed it.
  Buf data = std::move(readBuffer_);
  if (params.gro <= 0) {
    // Single datagram (no GRO).
    if (truncated) {
      // This is an error, drop the packet.
      return;
    }
    data->append(len);
    QUIC_STATS(statsCallback_, onPacketReceived);
    QUIC_STATS(statsCallback_, onRead, len);
    handleNetworkData(client, std::move(data), packetReceiveTime);
  } else {
    // GRO batch: the buffer holds consecutive datagrams of params.gro bytes
    // each; the final one may be shorter.
    // if we receive a truncated packet
    // we still need to consider the prev valid ones
    // AsyncUDPSocket::handleRead() sets the len to be the
    // buffer size in case the data is truncated
    if (truncated) {
      // Keep only the whole params.gro-sized segments.
      len -= len % params.gro;
    }
    data->append(len);
    QUIC_STATS(statsCallback_, onPacketReceived);
    QUIC_STATS(statsCallback_, onRead, len);
    size_t remaining = len;
    size_t offset = 0;
    while (remaining) {
      if (static_cast<int>(remaining) > params.gro) {
        // Not the last segment: hand off a clone that shares the buffer and
        // is trimmed to [offset, offset + params.gro).
        auto tmp = data->cloneOne();
        // start at offset
        tmp->trimStart(offset);
        // the actual len is len - offset now
        // leave params.gro_ bytes
        tmp->trimEnd(len - offset - params.gro);
        DCHECK_EQ(tmp->length(), params.gro);
        offset += params.gro;
        remaining -= params.gro;
        handleNetworkData(client, std::move(tmp), packetReceiveTime);
      } else {
        // do not clone the last packet
        // start at offset, use all the remaining data
        data->trimStart(offset);
        DCHECK_EQ(data->length(), remaining);
        remaining = 0;
        handleNetworkData(client, std::move(data), packetReceiveTime);
      }
    }
  }
}
// Parses just enough of one datagram's header to route it, then hands it to
// forwardNetworkData().
//
// Packets are dropped, with a stats reason, when:
//   - they arrive after shutdown,
//   - they come from a blocklisted source port,
//   - no worker callback is set, or
//   - the header cannot be parsed.
// Unparseable packets are first checked against the health-check token.
// Long-header packets may be consumed by version negotiation, and non-Initial
// and non-0-RTT long-header packets with a too-short DCID are dropped.
// Any std::exception thrown inside is caught and counted as a PARSE_ERROR
// drop.
//
// @param isForwardedData  Passed through to routing; forwardNetworkData()
//                         does not forward such packets again.
void QuicServerWorker::handleNetworkData(
    const folly::SocketAddress& client,
    Buf data,
    const TimePoint& packetReceiveTime,
    bool isForwardedData) noexcept {
  try {
    if (shutdown_) {
      VLOG(4) << "Packet received after shutdown, dropping";
      QUIC_STATS(
          statsCallback_, onPacketDropped, PacketDropReason::SERVER_SHUTDOWN);
      return;
    }
    if (isBlockListedSrcPort_(client.getPort())) {
      VLOG(4) << "Dropping packet with blocklisted src port: "
              << client.getPort();
      QUIC_STATS(
          statsCallback_, onPacketDropped, PacketDropReason::INVALID_SRC_PORT);
      return;
    }
    if (!callback_) {
      VLOG(0) << "Worker callback is null. Dropping packet.";
      QUIC_STATS(
          statsCallback_,
          onPacketDropped,
          PacketDropReason::WORKER_NOT_INITIALIZED);
      return;
    }
    folly::io::Cursor cursor(data.get());
    if (!cursor.canAdvance(sizeof(uint8_t))) {
      VLOG(4) << "Dropping packet too small";
      QUIC_STATS(
          statsCallback_, onPacketDropped, PacketDropReason::INVALID_PACKET);
      return;
    }
    // The first byte determines the header form (short vs long).
    uint8_t initialByte = cursor.readBE<uint8_t>();
    HeaderForm headerForm = getHeaderForm(initialByte);
    if (headerForm == HeaderForm::Short) {
      folly::Expected<ShortHeaderInvariant, TransportErrorCode>
          parsedShortHeader = parseShortHeaderInvariants(initialByte, cursor);
      if (!parsedShortHeader) {
        if (!tryHandlingAsHealthCheck(client, *data)) {
          QUIC_STATS(
              statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR);
          VLOG(6) << "Failed to parse short header";
        }
        return;
      }
      // Short headers carry no source CID and no version.
      RoutingData routingData(
          headerForm,
          false, /* isInitial */
          false, /* is0Rtt */
          false, /* isUsingClientConnId */
          std::move(parsedShortHeader->destinationConnId),
          folly::none);
      return forwardNetworkData(
          client,
          std::move(routingData),
          NetworkData(std::move(data), packetReceiveTime),
          folly::none, /* quicVersion */
          isForwardedData);
    }
    folly::Expected<ParsedLongHeaderInvariant, TransportErrorCode>
        parsedLongHeader = parseLongHeaderInvariant(initialByte, cursor);
    if (!parsedLongHeader) {
      if (!tryHandlingAsHealthCheck(client, *data)) {
        QUIC_STATS(
            statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR);
        VLOG(6) << "Failed to parse long header";
      }
      return;
    }
    // TODO: check version before looking at type
    LongHeader::Types longHeaderType = parseLongHeaderType(initialByte);
    bool isInitial = longHeaderType == LongHeader::Types::Initial;
    bool is0Rtt = longHeaderType == LongHeader::Types::ZeroRtt;
    // Initial and 0-RTT packets carry a client-chosen DCID rather than one
    // this server issued.
    bool isUsingClientConnId = isInitial || is0Rtt;
    if (isInitial) {
      // This stats gets updated even if the client initial will be dropped.
      QUIC_STATS(
          statsCallback_,
          onClientInitialReceived,
          parsedLongHeader->invariant.version);
    }
    if (maybeSendVersionNegotiationPacketOrDrop(
            client,
            isInitial,
            parsedLongHeader->invariant,
            data->computeChainDataLength())) {
      return;
    }
    if (!isUsingClientConnId &&
        parsedLongHeader->invariant.dstConnId.size() <
            kMinSelfConnectionIdV1Size) {
      // drop packet if connId is present but is not valid.
      VLOG(3) << "Dropping packet due to invalid connectionId";
      QUIC_STATS(
          statsCallback_, onPacketDropped, PacketDropReason::INVALID_PACKET);
      return;
    }
    RoutingData routingData(
        headerForm,
        isInitial,
        is0Rtt,
        isUsingClientConnId,
        std::move(parsedLongHeader->invariant.dstConnId),
        std::move(parsedLongHeader->invariant.srcConnId));
    return forwardNetworkData(
        client,
        std::move(routingData),
        NetworkData(std::move(data), packetReceiveTime),
        parsedLongHeader->invariant.version,
        isForwardedData);
  } catch (const std::exception& ex) {
    // Drop the packet.
    QUIC_STATS(statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR);
    VLOG(6) << "Failed to parse packet header " << ex.what();
  }
}
// Completion handler for an event-driven recvmsg.
//
// On a positive result, it extracts control-message metadata (when
// supported), clamps the byte count to the buffer length (flagging
// truncation), moves the filled buffer into readBuffer_, and dispatches the
// data via onDataAvailable(). In every case, ownership of `msgHdr` is taken
// back into msgHdr_.
void QuicServerWorker::eventRecvmsgCallback(MsgHdr* msgHdr, int res) {
  if (res > 0) {
    auto& hdr = msgHdr->data_;
    OnDataAvailableParams params;
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
    if (hdr.msg_control) {
      folly::AsyncUDPSocket::fromMsg(params, hdr);
    }
#endif
    const bool truncated = static_cast<size_t>(res) > msgHdr->len_;
    int bytesRead = res;
    if (truncated) {
      bytesRead = static_cast<ssize_t>(msgHdr->len_);
    }
    readBuffer_ = std::move(msgHdr->ioBuf_);
    folly::SocketAddress peer;
    peer.setFromSockaddr(
        reinterpret_cast<sockaddr*>(hdr.msg_name), hdr.msg_namelen);
    onDataAvailable(peer, bytesRead, truncated, params);
  }
  msgHdr_.reset(msgHdr);
}
// Called for packets that failed QUIC header parsing.
//
// If a health-check token is configured and `data` equals it exactly, this
// replies "OK" to `client` and returns true. Otherwise it returns false so the
// caller can count the packet as a parse-error drop. The reply is much smaller
// than the request, so it cannot be used for amplification; the write result
// is ignored.
bool QuicServerWorker::tryHandlingAsHealthCheck(
    const folly::SocketAddress& client,
    const folly::IOBuf& data) {
  if (!healthCheckToken_) {
    return false;
  }
  // TODO: make this constant time, the token might be secret, but we're
  // current assuming it's not.
  folly::IOBufEqualTo sameBytes;
  if (!sameBytes(*healthCheckToken_.value(), data)) {
    return false;
  }
  VLOG(4) << "Health check request, response=OK";
  socket_->write(client, folly::IOBuf::copyBuffer("OK"));
  return true;
}
// Routes parsed network data to the worker responsible for it.
//
// Packets whose server-chosen connection id cannot be parsed by connIdAlgo_
// (i.e. neither Initial nor 0-RTT) are handled here. They are forwarded to
// another process when packet forwarding is enabled and the packet was not
// itself forwarded; otherwise they are dropped. All other packets go to
// callback_->routeDataToWorker().
void QuicServerWorker::forwardNetworkData(
    const folly::SocketAddress& client,
    RoutingData&& routingData,
    NetworkData&& networkData,
    folly::Optional<QuicVersion> quicVersion,
    bool isForwardedData) {
  const bool unknownCidVersion = !routingData.isUsingClientConnId &&
      !connIdAlgo_->canParse(routingData.destinationConnId);
  if (!unknownCidVersion) {
    callback_->routeDataToWorker(
        client,
        std::move(routingData),
        std::move(networkData),
        std::move(quicVersion),
        isForwardedData);
    return;
  }

  // Never re-forward a packet that already came from another process.
  const bool mayForward = packetForwardingEnabled_ && !isForwardedData;
  if (!mayForward) {
    VLOG(3) << fmt::format(
        "Dropping packet due to unknown connectionId version, routingInfo={}",
        logRoutingInfo(routingData.destinationConnId));
    QUIC_STATS(
        statsCallback_,
        onPacketDropped,
        PacketDropReason::CONNECTION_NOT_FOUND);
    return;
  }

  VLOG(3) << fmt::format(
      "Forwarding packet with unknown connId version from client={} to another process, routingInfo={}",
      client.describe(),
      logRoutingInfo(routingData.destinationConnId));
  // Capture the timestamp before networkData is moved from.
  const auto receivedAt = networkData.receiveTimePoint;
  takeoverPktHandler_.forwardPacketToAnotherServer(
      client, std::move(networkData).moveAllData(), receivedAt);
  QUIC_STATS(statsCallback_, onPacketForwarded);
}
// Shares an externally owned pacing timer with this worker; start() creates a
// timer only if none was set here.
void QuicServerWorker::setPacingTimer(
    TimerHighRes::SharedPtr pacingTimer) noexcept {
  this->pacingTimer_ = std::move(pacingTimer);
}
// Delivers one datagram to a transport owned by this worker, creating one if
// needed.
//
// Lookup order:
//   1. connectionIdMap_, keyed by the destination CID.
//   2. For long-header packets, sourceAddressMap_, keyed by
//      (client address, client-chosen DCID).
//   3. For Initial packets with no match, a new transport is created. Token
//      validation and rate/handshake limits may trigger a Retry instead.
// 0-RTT packets with no match are buffered in pending0RttData_ and replayed
// once the matching Initial creates a transport.
// Packets that cannot be delivered locally are dropped, answered with a
// stateless reset (short headers only, see sendResetPacket), or forwarded to
// another process when packet forwarding is enabled.
void QuicServerWorker::dispatchPacketData(
    const folly::SocketAddress& client,
    RoutingData&& routingData,
    NetworkData&& networkData,
    folly::Optional<QuicVersion> quicVersion,
    bool isForwardedData) noexcept {
  DCHECK(socket_);
  QuicServerTransport::Ptr transport;
  bool dropPacket = false;
  auto cit = connectionIdMap_.find(routingData.destinationConnId);
  if (cit != connectionIdMap_.end()) {
    transport = cit->second;
    VLOG(10) << "Found existing connection for CID="
             << routingData.destinationConnId.hex() << " " << *transport;
  } else if (routingData.headerForm != HeaderForm::Long) {
    // Drop the packet if the header form is not long
    VLOG(3) << fmt::format(
        "Dropping non-long header packet with no connid match"
        " headerForm={}, routingInfo={}",
        static_cast<typename std::underlying_type<HeaderForm>::type>(
            routingData.headerForm),
        logRoutingInfo(routingData.destinationConnId));
    // Try forwarding the packet to the old server (if it is enabled)
    dropPacket = true;
  }
  // Set when the transport factory declined to make a transport; selects the
  // CANNOT_MAKE_TRANSPORT drop reason below.
  bool cannotMakeTransport = false;
  if (!dropPacket && !transport) {
    // For LongHeader packets without existing associated connection, try to
    // route with destinationConnId chosen by the peer and IP address of the
    // peer.
    CHECK(transportFactory_);
    auto source = std::make_pair(client, routingData.destinationConnId);
    auto sit = sourceAddressMap_.find(source);
    if (sit == sourceAddressMap_.end()) {
      // If it's a 0RTT packet and we have no CID, we probably lost the initial
      // and want to buffer it for a while.
      if (routingData.is0Rtt) {
        auto itr = pending0RttData_.find(routingData.destinationConnId);
        if (itr == pending0RttData_.end()) {
          itr =
              pending0RttData_.insert(routingData.destinationConnId, {}).first;
        }
        auto& vec = itr->second;
        // Once the per-DCID buffer is full, further 0-RTT packets are
        // silently discarded.
        if (vec.size() != vec.max_size()) {
          vec.emplace_back(std::move(networkData));
          QUIC_STATS(statsCallback_, onZeroRttBuffered);
        }
        return;
      } else if (!routingData.isInitial) {
        VLOG(3) << fmt::format(
            "Dropping packet from client={}, routingInfo={}",
            client.describe(),
            logRoutingInfo(routingData.destinationConnId));
        dropPacket = true;
      } else {
        VLOG(4) << fmt::format(
            "Creating new connection for client={}, routingInfo={}",
            client.describe(),
            logRoutingInfo(routingData.destinationConnId));
        // This could be a new connection, add it in the map
        // verify that the initial packet is at least min initial bytes
        // to avoid amplification attacks.
        if (networkData.totalData < kMinInitialPacketSize) {
          // Don't even attempt to forward the packet, just drop it.
          VLOG(3) << "Dropping small initial packet from client=" << client;
          QUIC_STATS(
              statsCallback_,
              onPacketDropped,
              PacketDropReason::INVALID_PACKET);
          return;
        }
        // If there is a token present, decrypt it (could be either a retry
        // token or a new token)
        folly::io::Cursor cursor(networkData.packets.front().get());
        auto maybeEncryptedToken = maybeGetEncryptedToken(cursor);
        bool hasTokenSecret = transportSettings_.retryTokenSecret.hasValue();
        // If the retryTokenSecret is not set, just skip evaluating validity of
        // token and assume true
        auto isValidRetryToken = !hasTokenSecret ||
            (maybeEncryptedToken &&
             validRetryToken(
                 *maybeEncryptedToken,
                 routingData.destinationConnId,
                 client.getIPAddress()));
        auto isValidNewToken = !hasTokenSecret ||
            (maybeEncryptedToken &&
             validNewToken(*maybeEncryptedToken, client.getIPAddress()));
        if (isValidNewToken) {
          QUIC_STATS(statsCallback_, onNewTokenReceived);
        } else if (maybeEncryptedToken && !isValidRetryToken) {
          // Failed to decrypt the token as either a new or retry token
          QUIC_STATS(statsCallback_, onTokenDecryptFailure);
        }
        // If rate-limiting is configured and there is no retry token,
        // send a retry packet back to the client
        if (!isValidRetryToken &&
            ((newConnRateLimiter_ &&
              newConnRateLimiter_->check(networkData.receiveTimePoint)) ||
             (unfinishedHandshakeLimitFn_.has_value() &&
              globalUnfinishedHandshakes >=
                  (*unfinishedHandshakeLimitFn_)()))) {
          if (hasTokenSecret) {
            sendRetryPacket(
                client,
                routingData.destinationConnId,
                routingData.sourceConnId.value_or(
                    ConnectionId(std::vector<uint8_t>())));
            QUIC_STATS(statsCallback_, onConnectionRateLimited);
            return;
          } else {
            VLOG(4)
                << "Not sending retry packet since retry token secret is not set";
          }
        }
        // Check that we have a proper quic version before creating transport.
        CHECK(quicVersion.has_value())
            << "no QUIC version supplied for transport creation";
        // create 'accepting' transport
        auto sock = makeSocket(getEventBase());
        auto trans = transportFactory_->make(
            getEventBase(), std::move(sock), client, quicVersion.value(), ctx_);
        if (!trans) {
          dropPacket = true;
          cannotMakeTransport = true;
        } else {
          globalUnfinishedHandshakes++;
          // Redundant with the !trans branch above; kept as an assertion.
          CHECK(trans);
          if (transportSettings_.dataPathType ==
                  DataPathType::ContinuousMemory &&
              bufAccessor_) {
            trans->setBufAccessor(bufAccessor_.get());
          }
          trans->setPacingTimer(pacingTimer_);
          trans->setRoutingCallback(this);
          trans->setHandshakeFinishedCallback(this);
          trans->setSupportedVersions(supportedVersions_);
          trans->setOriginalPeerAddress(client);
#ifdef CCP_ENABLED
          trans->setCcpDatapath(getCcpReader()->getDatapath());
#endif
          trans->setCongestionControllerFactory(ccFactory_);
          if (statsCallback_) {
            trans->setTransportStatsCallback(statsCallback_.get());
          }
          auto overridenTransportSettings = transportSettingsOverrideFn_
              ? transportSettingsOverrideFn_(
                    transportSettings_, client.getIPAddress())
              : folly::none;
          if (overridenTransportSettings) {
            if (overridenTransportSettings->dataPathType !=
                transportSettings_.dataPathType) {
              // It's too complex to support that.
              // NOTE(review): this only logs. The override below is still
              // applied with the requested dataPathType, while the
              // bufAccessor decision above used the worker's dataPathType.
              // Confirm that this combination is intended.
              LOG(ERROR)
                  << "Overriding DataPathType isn't supported. Requested datapath="
                  << (overridenTransportSettings->dataPathType ==
                              DataPathType::ContinuousMemory
                          ? "ContinuousMemory"
                          : "ChainedMemory");
            }
            trans->setTransportSettings(*overridenTransportSettings);
          } else {
            trans->setTransportSettings(transportSettings_);
          }
          trans->setConnectionIdAlgo(connIdAlgo_.get());
          trans->setServerConnectionIdRejector(this);
          if (routingData.sourceConnId) {
            trans->setClientConnectionId(*routingData.sourceConnId);
          }
          trans->setClientChosenDestConnectionId(routingData.destinationConnId);
          // parameters to create server chosen connection id
          ServerConnectionIdParams serverConnIdParams(
              cidVersion_,
              hostId_,
              static_cast<uint8_t>(processId_),
              workerId_);
          trans->setServerConnectionIdParams(std::move(serverConnIdParams));
          trans->accept();
          auto result = sourceAddressMap_.emplace(std::make_pair(
              std::make_pair(client, routingData.destinationConnId), trans));
          if (!result.second) {
            // NOTE(review): `trans` was already accept()ed and counted in
            // globalUnfinishedHandshakes, but it is not stored in any map
            // here. Confirm that it is cleaned up and decremented elsewhere.
            LOG(ERROR) << fmt::format(
                "Routing entry already exists for client={}, routingInfo={}",
                client.describe(),
                logRoutingInfo(routingData.destinationConnId));
            dropPacket = true;
          } else {
            for (const auto& observer : observerList_.getAll()) {
              observer->accept(trans.get());
            }
          }
          transport = trans;
        }
      }
    } else {
      transport = sit->second;
      VLOG(4) << "Found existing connection for client=" << client << " "
              << *transport;
    }
  }
  if (!dropPacket) {
    DCHECK(transport->getEventBase()->isInEventBaseThread());
    transport->onNetworkData(client, std::move(networkData));
    // If we had pending 0RTT data for this DCID, process it.
    if (routingData.isInitial && !pending0RttData_.empty()) {
      auto itr = pending0RttData_.find(routingData.destinationConnId);
      if (itr != pending0RttData_.end()) {
        for (auto& data : itr->second) {
          transport->onNetworkData(client, std::move(data));
        }
        pending0RttData_.erase(itr);
      }
    }
    return;
  }
  // --- Packet could not be delivered locally: drop, reset, or forward. ---
  if (cannotMakeTransport) {
    VLOG(3)
        << "Dropping packet due to transport factory did not make transport";
    QUIC_STATS(
        statsCallback_,
        onPacketDropped,
        PacketDropReason::CANNOT_MAKE_TRANSPORT);
    return;
  }
  if (!connIdAlgo_->canParse(routingData.destinationConnId)) {
    VLOG(3) << "Dropping packet with bad DCID, routingInfo="
            << logRoutingInfo(routingData.destinationConnId);
    QUIC_STATS(statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR);
    // TODO do we need to reset?
    return;
  }
  auto connIdParam =
      connIdAlgo_->parseConnectionId(routingData.destinationConnId);
  if (connIdParam.hasError()) {
    VLOG(3) << fmt::format(
        "Dropping packet due to DCID parsing error={}, , errorCode={}, routingInfo={}",
        connIdParam.error().what(),
        folly::to<std::string>(connIdParam.error().errorCode()),
        logRoutingInfo(routingData.destinationConnId));
    QUIC_STATS(statsCallback_, onPacketDropped, PacketDropReason::PARSE_ERROR);
    // TODO do we need to reset?
    return;
  }
  // The DCID encodes a different host: this host can't own the connection.
  if (connIdParam->hostId != hostId_) {
    VLOG_EVERY_N(2, 100) << fmt::format(
        "Dropping packet routed to wrong host, from client={}, routingInfo={},",
        client.describe(),
        logRoutingInfo(routingData.destinationConnId));
    QUIC_STATS(
        statsCallback_,
        onPacketDropped,
        PacketDropReason::ROUTING_ERROR_WRONG_HOST);
    return sendResetPacket(
        routingData.headerForm,
        client,
        networkData,
        routingData.destinationConnId);
  }
  if (!packetForwardingEnabled_ || isForwardedData) {
    QUIC_STATS(
        statsCallback_,
        onPacketDropped,
        PacketDropReason::CONNECTION_NOT_FOUND);
    return sendResetPacket(
        routingData.headerForm,
        client,
        networkData,
        routingData.destinationConnId);
  }
  // There's no existing connection for the packet's CID or the client's
  // addr, and doesn't belong to the old server. Send a Reset.
  if (connIdParam->processId == static_cast<uint8_t>(processId_)) {
    QUIC_STATS(
        statsCallback_,
        onPacketDropped,
        PacketDropReason::CONNECTION_NOT_FOUND);
    return sendResetPacket(
        routingData.headerForm,
        client,
        networkData,
        routingData.destinationConnId);
  }
  // Optimistically route to another server
  // if the packet type is not Initial and if there is not any connection
  // associated with the given packet
  VLOG(4) << fmt::format(
      "Forwarding packet from client={} to another process, routingInfo={}",
      client.describe(),
      logRoutingInfo(routingData.destinationConnId));
  auto recvTime = networkData.receiveTimePoint;
  takeoverPktHandler_.forwardPacketToAnotherServer(
      client, std::move(networkData).moveAllData(), recvTime);
  QUIC_STATS(statsCallback_, onPacketForwarded);
}
// Sends a stateless reset to `client` in response to a short-header packet
// that could not be matched to any connection.
//
// Long-header packets never trigger a reset. The reset token is derived from
// `connId` using transportSettings_.statelessResetTokenSecret, which must be
// set.
//
// The reset is always strictly smaller than the packet that triggered it
// (RFC 9000 §10.3 and §10.3.3). This prevents two endpoints from bouncing
// resets back and forth, and keeps the reset from acting as an amplification
// vector. If the triggering packet is too small to allow a smaller reset, no
// reset is sent.
void QuicServerWorker::sendResetPacket(
    const HeaderForm& headerForm,
    const folly::SocketAddress& client,
    const NetworkData& networkData,
    const ConnectionId& connId) {
  if (headerForm != HeaderForm::Short) {
    // Only send resets in response to short header packets.
    return;
  }
  auto packetSize = networkData.totalData;
  // The smallest reset the builder can emit is kMinStatelessPacketSize bytes.
  // If that is not smaller than the incoming packet, any reset would break
  // the "smaller than the trigger" rule (and could amplify). This guard also
  // keeps `packetSize - 1` below from underflowing.
  if (packetSize <= kMinStatelessPacketSize) {
    VLOG(4) << "Not sending stateless reset for packet of size=" << packetSize
            << " to client=" << client;
    return;
  }
  auto resetSize = std::min<uint16_t>(packetSize, kDefaultMaxUDPPayload);
  // Per the spec, less than 43 we should respond with packet size - 1.
  if (packetSize < 43) {
    resetSize = std::max<uint16_t>(packetSize - 1, kMinStatelessPacketSize);
  } else {
    // Random size below resetSize (<= packetSize), floored at the minimum
    // reset size, which the guard above ensures is below packetSize.
    resetSize = std::max<uint16_t>(
        folly::Random::secureRand32() % resetSize, kMinStatelessPacketSize);
  }
  CHECK(transportSettings_.statelessResetTokenSecret.has_value());
  StatelessResetGenerator generator(
      *transportSettings_.statelessResetTokenSecret,
      getAddress().getFullyQualified());
  StatelessResetToken token = generator.generateToken(connId);
  StatelessResetPacketBuilder builder(resetSize, token);
  auto resetData = std::move(builder).buildPacket();
  auto resetDataLen = resetData->computeChainDataLength();
  socket_->write(client, std::move(resetData));
  QUIC_STATS(statsCallback_, onWrite, resetDataLen);
  QUIC_STATS(statsCallback_, onPacketSent);
  QUIC_STATS(statsCallback_, onStatelessReset);
}
// Extracts the (still encrypted) token from a long-header packet.
//
// @param cursor  Positioned at the packet's first byte; it is advanced past
//                the parsed header.
// @return The token bytes, or folly::none if the packet is empty, the long
//         header fails to parse, or the header carries no token.
folly::Optional<std::string> QuicServerWorker::maybeGetEncryptedToken(
    folly::io::Cursor& cursor) {
  // Move cursor to the byte right after the initial byte
  if (!cursor.canAdvance(1)) {
    return folly::none;
  }
  auto initialByte = cursor.readBE<uint8_t>();
  // The visible caller only invokes this for Initial packets, which use a
  // long header.
  auto parsedLongHeader = parseLongHeader(initialByte, cursor);
  if (!parsedLongHeader || !parsedLongHeader->parsedLongHeader.has_value()) {
    return folly::none;
  }
  // Bind by reference: copying the LongHeader would needlessly duplicate its
  // connection ids and token just to read the token.
  auto& header = parsedLongHeader->parsedLongHeader.value().header;
  if (!header.hasToken()) {
    return folly::none;
  }
  return header.getToken();
}
/**
 * Returns true when a token issued at `tokenIssuedMs` (milliseconds since the
 * system-clock epoch) is still within its validity window `kTokenValidMs`
 * (e.g. 1 day for new tokens, 5 minutes for retry tokens).
 *
 * The age is measured as an absolute difference: the system clock may be
 * stepped forwards or backwards (e.g. by NTP), so an issue time slightly in
 * the future is tolerated within the same window.
 */
bool checkTokenAge(uint64_t tokenIssuedMs, uint64_t kTokenValidMs) {
  using std::chrono::duration_cast;
  using std::chrono::milliseconds;
  const uint64_t currentMs =
      duration_cast<milliseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count();
  uint64_t ageMs = 0;
  if (currentMs >= tokenIssuedMs) {
    ageMs = currentMs - tokenIssuedMs;
  } else {
    ageMs = tokenIssuedMs - currentMs;
  }
  return !(ageMs > kTokenValidMs);
}
// Returns true when `encryptedToken` decrypts, under the retry-token secret,
// as a retry token bound to (`dstConnId`, `clientIp`) and is younger than
// kMaxRetryTokenValidMs. The retry-token secret must be set.
bool QuicServerWorker::validRetryToken(
    std::string& encryptedToken,
    const ConnectionId& dstConnId,
    const folly::IPAddress& clientIp) {
  CHECK(transportSettings_.retryTokenSecret.hasValue());
  TokenGenerator generator(transportSettings_.retryTokenSecret.value());
  // A pseudo token used only to rebuild the AEAD associated data. The 0
  // argument stands in for the client port. NOTE(review): this assumes the
  // associated data does not depend on the port; confirm.
  RetryToken pseudoToken(dstConnId, clientIp, 0);
  const auto issuedMs = generator.decryptToken(
      folly::IOBuf::copyBuffer(encryptedToken), pseudoToken.genAeadAssocData());
  if (!issuedMs) {
    return false;
  }
  return checkTokenAge(issuedMs, kMaxRetryTokenValidMs);
}
// Returns true when `encryptedToken` decrypts, under the retry-token secret,
// as a NEW_TOKEN token bound to `clientIp` and is younger than
// kMaxNewTokenValidMs. The retry-token secret must be set.
bool QuicServerWorker::validNewToken(
    std::string& encryptedToken,
    const folly::IPAddress& clientIp) {
  CHECK(transportSettings_.retryTokenSecret.hasValue());
  TokenGenerator generator(transportSettings_.retryTokenSecret.value());
  // A pseudo token used only to rebuild the AEAD associated data.
  NewToken pseudoToken(clientIp);
  const auto issuedMs = generator.decryptToken(
      folly::IOBuf::copyBuffer(encryptedToken), pseudoToken.genAeadAssocData());
  if (!issuedMs) {
    return false;
  }
  return checkTokenAge(issuedMs, kMaxNewTokenValidMs);
}
// Sends a Retry packet carrying a freshly encrypted retry token to `client`.
//
// @param dstConnId  DCID of the client's Initial. It becomes the Retry's
//                   source CID and the "original destination CID" used for
//                   the integrity tag.
// @param srcConnId  SCID of the client's Initial. It becomes the Retry's
//                   destination CID.
// The retry-token secret must be set (value() is called on it).
void QuicServerWorker::sendRetryPacket(
    const folly::SocketAddress& client,
    const ConnectionId& dstConnId,
    const ConnectionId& srcConnId) {
  // RetryToken defaults to currentTimeInMs
  TokenGenerator tokenGenerator(transportSettings_.retryTokenSecret.value());
  RetryToken retryToken(dstConnId, client.getIPAddress(), client.getPort());
  auto maybeToken = tokenGenerator.encryptToken(retryToken);
  CHECK(maybeToken.has_value());
  std::string tokenStr = maybeToken.value()->moveToFbString().toStdString();

  // For the client to validate the integrity tag, the pseudo packet's initial
  // byte must match the initial byte of the real Retry packet.
  const uint8_t retryInitialByte = kHeaderFormMask |
      LongHeader::kFixedBitMask |
      (static_cast<uint8_t>(LongHeader::Types::Retry)
       << LongHeader::kTypeShift);

  // Source and destination CIDs are swapped relative to the client's packet
  // (section 7.3 of the QUIC draft) in both the pseudo and the real builder.
  PseudoRetryPacketBuilder pseudoBuilder(
      retryInitialByte,
      dstConnId, /* src conn id */
      srcConnId, /* dst conn id */
      dstConnId, /* original dst conn id */
      QuicVersion::MVFST_INVALID,
      folly::IOBuf::copyBuffer(tokenStr));
  Buf pseudoPacket = std::move(pseudoBuilder).buildPacket();
  FizzRetryIntegrityTagGenerator tagGenerator;
  auto integrityTag = tagGenerator.getRetryIntegrityTag(
      QuicVersion::MVFST_INVALID, pseudoPacket.get());

  RetryPacketBuilder retryBuilder(
      dstConnId, /* src conn id */
      srcConnId, /* dst conn id */
      QuicVersion::MVFST_INVALID,
      std::move(tokenStr),
      std::move(integrityTag));
  auto retryPacket = std::move(retryBuilder).buildPacket();
  const auto retryPacketLen = retryPacket->computeChainDataLength();
  socket_->write(client, retryPacket);
  QUIC_STATS(statsCallback_, onWrite, retryPacketLen);
  QUIC_STATS(statsCallback_, onPacketSent);
}
// Creates the takeover handler on `socket` and binds it to `address`, so that
// a successor process can forward packets to this worker. May only be called
// once (takeoverCB_ must be unset). The handler lives until
// shutdownAllConnections().
void QuicServerWorker::allowBeingTakenOver(
    std::unique_ptr<folly::AsyncUDPSocket> socket,
    const folly::SocketAddress& address) {
  DCHECK(!takeoverCB_);
  auto handler = std::make_unique<TakeoverHandlerCallback>(
      this, takeoverPktHandler_, transportSettings_, std::move(socket));
  takeoverCB_ = std::move(handler);
  takeoverCB_->bind(address);
}
// Rebinds the existing takeover handler to a new socket/address and returns
// the address it ended up bound to. Requires allowBeingTakenOver() first.
const folly::SocketAddress& QuicServerWorker::overrideTakeoverHandlerAddress(
    std::unique_ptr<folly::AsyncUDPSocket> socket,
    const folly::SocketAddress& address) {
  CHECK(takeoverCB_ != nullptr);
  auto& handler = *takeoverCB_;
  handler.rebind(std::move(socket), address);
  return handler.getAddress();
}
void QuicServerWorker::startPacketForwarding(
const folly::SocketAddress& destAddr) {
packetForwardingEnabled_ = true;
takeoverPktHandler_.setDestination(destAddr);
}
void QuicServerWorker::stopPacketForwarding() {
packetForwardingEnabled_ = false;
takeoverPktHandler_.stop();
}
// Socket read-error hook: reports INTERNAL_ERROR to the worker callback, or
// logs and drops the error when no callback is attached (e.g. after
// shutdownAllConnections() cleared it).
void QuicServerWorker::onReadError(
    const folly::AsyncSocketException& ex) noexcept {
  VLOG(4) << "QuicServer readerr: " << ex.what();
  if (callback_) {
    callback_->handleWorkerError(LocalErrorCode::INTERNAL_ERROR);
    return;
  }
  VLOG(0) << "Worker callback is null. Ignoring worker error.";
}
// Socket read-closed hook: shuts every connection down with SHUTTING_DOWN.
void QuicServerWorker::onReadClosed() noexcept {
  const auto reason = LocalErrorCode::SHUTTING_DOWN;
  shutdownAllConnections(reason);
}
// Returns the fd of the takeover handler's socket; requires that
// allowBeingTakenOver() has installed the handler (CHECKed).
int QuicServerWorker::getTakeoverHandlerSocketFD() {
  CHECK(takeoverCB_);
  const auto fd = takeoverCB_->getSocketFD();
  return fd;
}
// Returns the takeover protocol version reported by the packet handler.
TakeoverProtocolVersion QuicServerWorker::getTakeoverProtocolVersion()
    const noexcept {
  const auto version = takeoverPktHandler_.getTakeoverProtocolVersion();
  return version;
}
void QuicServerWorker::setProcessId(enum ProcessId id) noexcept {
processId_ = id;
}
ProcessId QuicServerWorker::getProcessId() const noexcept {
return processId_;
}
void QuicServerWorker::setWorkerId(uint8_t id) noexcept {
workerId_ = id;
}
uint8_t QuicServerWorker::getWorkerId() const noexcept {
return workerId_;
}
void QuicServerWorker::setHostId(uint32_t hostId) noexcept {
hostId_ = hostId;
}
void QuicServerWorker::setConnectionIdVersion(
ConnectionIdVersion cidVersion) noexcept {
cidVersion_ = cidVersion;
}
// Returns a non-owning pointer to the worker's CCPReader (created in the
// constructor); ownership stays with the worker.
CCPReader* QuicServerWorker::getCcpReader() const noexcept {
  CCPReader* reader = ccpReader_.get();
  return reader;
}
// Stores the (non-owning) socket factory and shares the same factory with the
// takeover packet handler.
void QuicServerWorker::setNewConnectionSocketFactory(
    QuicUDPSocketFactory* factory) {
  socketFactory_ = factory;
  takeoverPktHandler_.setSocketFactory(factory);
}
void QuicServerWorker::setTransportFactory(
QuicServerTransportFactory* factory) {
transportFactory_ = factory;
}
void QuicServerWorker::setSupportedVersions(
const std::vector<QuicVersion>& supportedVersions) {
supportedVersions_ = supportedVersions;
}
// Stores the Fizz server context shared by this worker's transports.
// `ctx` is a by-value sink parameter, so it is moved into the member rather
// than copied (avoids an extra atomic refcount increment/decrement pair).
void QuicServerWorker::setFizzContext(
    std::shared_ptr<const fizz::server::FizzServerContext> ctx) {
  ctx_ = std::move(ctx);
}
// Stores the worker's transport settings after normalizing the data path:
//  - ContinuousMemory is only kept together with GSO batching; any other
//    batching mode forces ChainedMemory (logging an error if ContinuousMemory
//    had been requested).
//  - If ContinuousMemory survives, a SimpleBufAccessor of
//    kDefaultMaxUDPPayload * maxBatchSize bytes is created for GSO writes.
void QuicServerWorker::setTransportSettings(
    TransportSettings transportSettings) {
  // By-value sink parameter: move it in instead of copying the whole struct.
  transportSettings_ = std::move(transportSettings);
  if (transportSettings_.batchingMode != QuicBatchingMode::BATCHING_MODE_GSO) {
    if (transportSettings_.dataPathType == DataPathType::ContinuousMemory) {
      LOG(ERROR) << "Unsupported data path type and batching mode combination";
    }
    transportSettings_.dataPathType = DataPathType::ChainedMemory;
  }
  if (transportSettings_.dataPathType == DataPathType::ContinuousMemory) {
    // TODO: maxBatchSize is only a good start value when each transport does
    // its own socket writing. If we experiment with multiple transports GSO
    // together, we will need a better value.
    bufAccessor_ = std::make_unique<SimpleBufAccessor>(
        kDefaultMaxUDPPayload * transportSettings_.maxBatchSize);
    VLOG(10) << "GSO write buf accessor created for ContinuousMemory data path";
  }
}
void QuicServerWorker::rejectNewConnections(
std::function<bool()> rejectNewConnections) {
rejectNewConnections_ = std::move(rejectNewConnections);
}
void QuicServerWorker::setIsBlockListedSrcPort(
std::function<bool(uint16_t)> isBlockListedSrcPort) {
isBlockListedSrcPort_ = std::move(isBlockListedSrcPort);
}
void QuicServerWorker::setHealthCheckToken(
const std::string& healthCheckToken) {
healthCheckToken_ = folly::IOBuf::copyBuffer(healthCheckToken);
}
// Creates a socket on `evb` through the socket factory, reusing the fd of
// this worker's own socket (which must exist; CHECKed).
std::unique_ptr<folly::AsyncUDPSocket> QuicServerWorker::makeSocket(
    folly::EventBase* evb) const {
  CHECK(socket_);
  const auto workerFd = socket_->getNetworkSocket().toFd();
  return socketFactory_->make(evb, workerFd);
}
// Creates a socket on `evb` for the caller-supplied `fd` via the socket
// factory.
std::unique_ptr<folly::AsyncUDPSocket> QuicServerWorker::makeSocket(
    folly::EventBase* evb,
    int fd) const {
  auto* factory = socketFactory_;
  return factory->make(evb, fd);
}
// Read-only view of the CID -> transport routing table.
const QuicServerWorker::ConnIdToTransportMap&
QuicServerWorker::getConnectionIdMap() const {
  return this->connectionIdMap_;
}
// Read-only view of the source-identity -> transport table (transports whose
// connection ids are not bound yet).
const QuicServerWorker::SrcToTransportMap&
QuicServerWorker::getSrcToTransportMap() const {
  return this->sourceAddressMap_;
}
// Registers `transport` in connectionIdMap_ under connection id `id`.
// An existing entry for `id` is never overwritten; instead an error is logged
// saying whether that entry is the same transport or a different one.
// On a successful insert the transport is also recorded (as a weak_ptr keyed
// by raw pointer) in boundServerTransports_, and the onNewConnection stat is
// emitted the first time a given transport is recorded there.
void QuicServerWorker::onConnectionIdAvailable(
    QuicServerTransport::Ptr transport,
    ConnectionId id) noexcept {
  VLOG(4) << "Adding into connectionIdMap_ for CID=" << id << " " << *transport;
  // Capture non-owning handles before `transport` is moved into the map.
  QuicServerTransport* transportPtr = transport.get();
  std::weak_ptr<QuicServerTransport> weakTransport = transport;
  // Emplace key and value directly; no intermediate std::pair temporary.
  auto result = connectionIdMap_.emplace(id, std::move(transport));
  if (!result.second) {
    // In the case of duplicates, log if they represent the same transport,
    // or different ones.
    QuicServerTransport* existingTransportPtr = result.first->second.get();
    LOG(ERROR) << "connectionIdMap_ already has CID=" << id
               << " Is same transport: "
               << (existingTransportPtr == transportPtr);
  } else if (boundServerTransports_.emplace(transportPtr, weakTransport)
                 .second) {
    QUIC_STATS(statsCallback_, onNewConnection);
  }
}
// Removes `transport` from sourceAddressMap_ now that its connection id is
// bound. The entry is keyed by (original peer address, client-chosen
// destination CID); it is erased only if it maps to this very transport,
// otherwise an error is logged and the map is left untouched.
void QuicServerWorker::onConnectionIdBound(
    QuicServerTransport::Ptr transport) noexcept {
  auto clientInitialDestCid = transport->getClientChosenDestConnectionId();
  CHECK(clientInitialDestCid);
  auto source = std::make_pair(
      transport->getOriginalPeerAddress(), *clientInitialDestCid);
  VLOG(4) << "Removing from sourceAddressMap_ address=" << source.first;
  auto iter = sourceAddressMap_.find(source);
  if (iter == sourceAddressMap_.end() || iter->second != transport) {
    LOG(ERROR) << "Transport not match, client=" << *transport;
  } else {
    // Erase through the iterator we already hold instead of hashing and
    // looking up `source` a second time.
    sourceAddressMap_.erase(iter);
  }
}
// Forgets `transport`. Presumably reached through the transport's routing
// callback when the connection goes away; the callback is detached below so
// this runs only once per transport.
//  - Emits onConnectionCloseZeroBytesWritten if the transport never sent a
//    byte, unless its local error is CONNECTION_ABANDONED.
//  - Removes the transport from boundServerTransports_, every CID in
//    `connectionIdData` from connectionIdMap_, and `source` from
//    sourceAddressMap_.
// A CID that maps to a *different* transport is still erased, but logged, and
// boundServerTransports_ is checked for that other transport to help diagnose
// duplicate-CID bugs.
void QuicServerWorker::onConnectionUnbound(
    QuicServerTransport* transport,
    const QuicServerTransport::SourceIdentity& source,
    const std::vector<ConnectionIdData>& connectionIdData) noexcept {
  VLOG(4) << "Removing from sourceAddressMap_ address=" << source.first;
  auto& localConnectionError = transport->getState()->localConnectionError;
  if (transport->getConnectionsStats().totalBytesSent == 0 &&
      !(localConnectionError && localConnectionError->code.asLocalErrorCode() &&
        *localConnectionError->code.asLocalErrorCode() ==
            LocalErrorCode::CONNECTION_ABANDONED)) {
    QUIC_STATS(statsCallback_, onConnectionCloseZeroBytesWritten);
  }
  // Ensures we only process `onConnectionUnbound()` once.
  transport->setRoutingCallback(nullptr);
  boundServerTransports_.erase(transport);
  for (const auto& connId : connectionIdData) {
    VLOG(4) << fmt::format(
        "Removing CID from connectionIdMap_, routingInfo={}",
        logRoutingInfo(connId.connId));
    auto it = connectionIdMap_.find(connId.connId);
    if (it == connectionIdMap_.end()) {
      VLOG(3) << "CID not found in connectionIdMap_ CID= " << connId.connId;
      continue;
    }
    // This should be nullptr in most cases. In order to investigate if
    // an incorrect server transport is removed, this will be set to the value
    // of the incorrect transport, to see if boundServerTransports_ will
    // still hold a pointer to the incorrect transport.
    QuicServerTransport* incorrectTransportPtr = nullptr;
    QuicServerTransport* existingPtr = it->second.get();
    if (existingPtr != transport) {
      LOG(ERROR) << "Incorrect transport being removed for duplicate CID="
                 << connId.connId;
      incorrectTransportPtr = existingPtr;
    }
    // Erase via the iterator from find() rather than a second key lookup.
    connectionIdMap_.erase(it);
    // Only pointer identity is used below; the pointer is never dereferenced.
    if (incorrectTransportPtr != nullptr &&
        boundServerTransports_.find(incorrectTransportPtr) !=
            boundServerTransports_.end()) {
      LOG(ERROR)
          << "boundServerTransports_ contains deleted transport for duplicate CID="
          << connId.connId;
    }
  }
  sourceAddressMap_.erase(source);
}
// Handshake-completion hook: decrements the namespace-scope atomic
// `globalUnfinishedHandshakes`. CHECK_GE (unlike DCHECK) is evaluated in every
// build mode, so the decrement always happens; the process aborts if the
// counter drops below zero. NOTE(review): the matching increment is not in
// this chunk — presumably performed when a transport starts its handshake.
void QuicServerWorker::onHandshakeFinished() noexcept {
  CHECK_GE(--globalUnfinishedHandshakes, 0);
}
// Hook for a handshake that ends without finishing: performs exactly the same
// decrement and CHECK as onHandshakeFinished(), so either outcome releases
// one count of `globalUnfinishedHandshakes`.
void QuicServerWorker::onHandshakeUnfinished() noexcept {
  CHECK_GE(--globalUnfinishedHandshakes, 0);
}
// Tears this worker down: pauses reads on the worker socket and takeover
// handler, drops the worker callback, closes every transport still tracked
// (with `error` and the message "shutting down"), clears the routing maps,
// stops the takeover packet handler, and releases the stats callback, socket,
// takeover handler and pacing timer.
// Idempotent: guarded by `shutdown_`, so repeat calls (e.g. onReadClosed()
// followed by the destructor) return after logging.
void QuicServerWorker::shutdownAllConnections(LocalErrorCode error) {
  VLOG(4) << "QuicServer shutdown all connections."
          << " addressMap=" << sourceAddressMap_.size()
          << " connectionIdMap=" << connectionIdMap_.size();
  if (shutdown_) {
    return;
  }
  shutdown_ = true;
  // Stop taking in new packets before closing anything.
  if (socket_) {
    socket_->pauseRead();
  }
  if (takeoverCB_) {
    takeoverCB_->pause();
  }
  // After this, onReadError() only logs instead of reporting to the owner.
  callback_ = nullptr;
  // Shut down all transports without bound connection ids.
  for (auto& it : sourceAddressMap_) {
    // Copying the shared_ptr keeps the transport alive across closeNow().
    auto transport = it.second;
    // Detach the worker's callbacks first. NOTE(review): presumably so that
    // closeNow() cannot route back into onConnectionUnbound(), which erases
    // from the maps being iterated here.
    transport->setRoutingCallback(nullptr);
    transport->setTransportStatsCallback(nullptr);
    transport->setHandshakeFinishedCallback(nullptr);
    transport->closeNow(
        QuicError(QuicErrorCode(error), std::string("shutting down")));
  }
  // Shut down all transports with bound connection ids.
  // boundServerTransports_ holds weak_ptrs; transports already destroyed are
  // skipped because lock() returns null for them.
  for (auto transport : boundServerTransports_) {
    if (auto t = transport.second.lock()) {
      t->setRoutingCallback(nullptr);
      t->setTransportStatsCallback(nullptr);
      t->setHandshakeFinishedCallback(nullptr);
      t->closeNow(
          QuicError(QuicErrorCode(error), std::string("shutting down")));
    }
  }
  // NOTE(review): boundServerTransports_ itself is not cleared here; it only
  // holds weak_ptrs, but confirm that is intentional.
  sourceAddressMap_.clear();
  connectionIdMap_.clear();
  takeoverPktHandler_.stop();
  if (statsCallback_) {
    statsCallback_.reset();
  }
  socket_.reset();
  takeoverCB_.reset();
  pacingTimer_.reset();
}
// Destroying the worker implies a SHUTTING_DOWN shutdown; this is a no-op
// beyond logging if shutdownAllConnections() has already run.
QuicServerWorker::~QuicServerWorker() {
  const auto reason = LocalErrorCode::SHUTTING_DOWN;
  shutdownAllConnections(reason);
}
// A candidate connection id is rejected iff it is already routed by this
// worker (present as a key in connectionIdMap_).
bool QuicServerWorker::rejectConnectionId(
    const ConnectionId& candidate) const noexcept {
  return connectionIdMap_.count(candidate) != 0;
}
// Builds a human-readable routing description of `connId` for logging.
// Always contains the CID (hex), this worker's cidVersion/workerId/processId/
// hostId and the current thread id. When connIdAlgo_ can parse the CID, the
// version/workerId/processId/hostId encoded in the CID are appended as well.
// Every field is followed by ", ", matching the historical log format.
std::string QuicServerWorker::logRoutingInfo(const ConnectionId& connId) const {
  std::string routingInfo = fmt::format(
      "CID={}, cidVersion={}, workerId={}, processId={}, hostId={}, threadId={}, ",
      connId.hex(),
      static_cast<uint32_t>(cidVersion_),
      static_cast<uint32_t>(workerId_),
      static_cast<uint32_t>(processId_),
      static_cast<uint32_t>(hostId_),
      folly::getCurrentThreadID());
  if (!connIdAlgo_->canParse(connId)) {
    return routingInfo;
  }
  auto connIdParam = connIdAlgo_->parseConnectionId(connId);
  if (connIdParam.hasError()) {
    return routingInfo;
  }
  // Append the parsed fields with a second compile-time-checked format call
  // instead of concatenating a runtime format string and passing temporaries
  // through fmt::make_format_args (which accepts only lvalues in newer fmt).
  routingInfo += fmt::format(
      "cidVersion in packet={}, workerId in packet={}, "
      "processId in packet={}, hostId in packet={}, ",
      static_cast<uint32_t>(connIdParam->version),
      static_cast<uint32_t>(connIdParam->workerId),
      static_cast<uint32_t>(connIdParam->processId),
      static_cast<uint32_t>(connIdParam->hostId));
  return routingInfo;
}
// The observer list remembers the worker it reports to; observers receive
// that worker pointer in every notification.
QuicServerWorker::AcceptObserverList::AcceptObserverList(
    QuicServerWorker* worker)
    : worker_(worker) {}
// On destruction, every registered observer is told the acceptor (worker) is
// going away.
QuicServerWorker::AcceptObserverList::~AcceptObserverList() {
  std::for_each(
      observers_.begin(), observers_.end(), [this](const auto& observer) {
        observer->acceptorDestroy(worker_);
      });
}
// Registers a non-null observer exactly once (both conditions CHECKed), then
// notifies it that it is attached to this list's worker.
void QuicServerWorker::AcceptObserverList::add(AcceptObserver* observer) {
  // adding the same observer multiple times is not allowed
  CHECK(
      std::find(observers_.begin(), observers_.end(), observer) ==
      observers_.end());
  AcceptObserver* const attached = CHECK_NOTNULL(observer);
  observers_.emplace_back(attached);
  attached->observerAttach(worker_);
}
// Unregisters `observer`. If it was registered, it is notified of the detach
// before being dropped and true is returned; otherwise nothing happens and
// false is returned.
bool QuicServerWorker::AcceptObserverList::remove(AcceptObserver* observer) {
  const auto pos = std::find(observers_.begin(), observers_.end(), observer);
  if (pos != observers_.end()) {
    observer->observerDetach(worker_);
    observers_.erase(pos);
    return true;
  }
  return false;
}
void QuicServerWorker::getAllConnectionsStats(
std::vector<QuicConnectionStats>& stats) {
folly::F14FastMap<QuicServerTransport::Ptr, uint32_t> uniqueConns;
for (const auto& conn : connectionIdMap_) {
if (!conn.second) {
continue;
}
auto connState =
static_cast<const QuicServerConnectionState*>(conn.second->getState());
if (!connState) {
continue;
}
uniqueConns[conn.second]++;
}
stats.reserve(stats.size() + uniqueConns.size());
for (const auto& connEntry : uniqueConns) {
QuicConnectionStats connStats = connEntry.first->getConnectionsStats();
connStats.workerID = workerId_;
connStats.numConnIDs = connEntry.second;
stats.emplace_back(connStats);
}
}
} // namespace quic