quic/server/handshake/ServerHandshake.cpp (393 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <quic/server/handshake/ServerHandshake.h>
#include <quic/fizz/handshake/FizzBridge.h>
#include <quic/fizz/handshake/FizzCryptoFactory.h>
#include <quic/state/QuicStreamFunctions.h>
namespace quic {
/**
 * Creates the server handshake for one connection.
 *
 * Stores the connection pointer in conn_, starts with an empty (null)
 * actionGuard_, and initializes cryptoState_ from *conn->cryptoState. Both
 * `conn` and `conn->cryptoState` are dereferenced here, so they must be
 * non-null when the constructor runs.
 */
ServerHandshake::ServerHandshake(QuicConnectionStateBase* conn)
: conn_(conn), actionGuard_(nullptr), cryptoState_(*conn->cryptoState) {}
/**
 * Stores the server transport parameters extension and starts accepting the
 * handshake via processAccept().
 *
 * inHandshakeStack_ is held true for the duration of the call so that
 * processActions() does not invoke the callback while the caller is still on
 * the stack; it is reset on every exit path, including exceptions.
 */
void ServerHandshake::accept(
    std::shared_ptr<ServerTransportParametersExtension> transportParams) {
  transportParams_ = std::move(transportParams);
  inHandshakeStack_ = true;
  SCOPE_EXIT {
    inHandshakeStack_ = false;
  };
  processAccept();
}
/**
 * Sets up the handshake before use.
 *
 * @param executor  Stored in executor_; startActions() schedules the
 *                  continuation of asynchronous fizz actions on it.
 * @param callback  Forwarded to initializeImpl().
 * @param validator App token validator; ownership is passed on to
 *                  initializeImpl().
 */
void ServerHandshake::initialize(
folly::Executor* executor,
HandshakeCallback* callback,
std::unique_ptr<fizz::server::AppTokenValidator> validator) {
// Record the executor first so it is already set when initializeImpl() runs.
executor_ = executor;
initializeImpl(callback, std::move(validator));
}
/**
 * Feeds crypto data received at the given encryption level into the
 * handshake and processes whatever can be processed.
 *
 * The data is appended to the read buffer that belongs to the encryption
 * level (EarlyData and AppData share one buffer). An unknown level is fatal.
 *
 * @throws QuicTransportException if an error has been recorded in error_ by
 *         the time processing returns.
 */
void ServerHandshake::doHandshake(
    std::unique_ptr<folly::IOBuf> data,
    EncryptionLevel encryptionLevel) {
  // Suppress callback notifications from processActions() while the caller is
  // on the stack; the flag is cleared on every exit path.
  inHandshakeStack_ = true;
  SCOPE_EXIT {
    inHandshakeStack_ = false;
  };
  // Clearing waitForData_ lets processPendingEvents() hand the buffered bytes
  // to processSocketData() again.
  waitForData_ = false;
  switch (encryptionLevel) {
    case EncryptionLevel::Initial:
      initialReadBuf_.append(std::move(data));
      break;
    case EncryptionLevel::Handshake:
      handshakeReadBuf_.append(std::move(data));
      break;
    case EncryptionLevel::EarlyData:
    case EncryptionLevel::AppData:
      appDataReadBuf_.append(std::move(data));
      break;
    default:
      LOG(FATAL) << "Unhandled EncryptionLevel";
  }
  processPendingEvents();
  if (!error_) {
    return;
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Writes a new session ticket carrying `appToken` through
 * writeNewSessionTicketToCrypto() and then drives any pending handshake
 * events.
 *
 * @throws QuicTransportException if an error has been recorded in error_ by
 *         the time processing returns.
 */
void ServerHandshake::writeNewSessionTicket(const AppToken& appToken) {
  // Suppress callback notifications from processActions() while the caller is
  // on the stack; the flag is cleared on every exit path.
  inHandshakeStack_ = true;
  SCOPE_EXIT {
    inHandshakeStack_ = false;
  };
  writeNewSessionTicketToCrypto(appToken);
  processPendingEvents();
  if (!error_) {
    return;
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the handshake read cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<Aead> ServerHandshake::getHandshakeReadCipher() {
  if (!error_) {
    return std::move(handshakeReadCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 1-RTT write cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<Aead> ServerHandshake::getOneRttWriteCipher() {
  if (!error_) {
    return std::move(oneRttWriteCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 1-RTT read cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<Aead> ServerHandshake::getOneRttReadCipher() {
  if (!error_) {
    return std::move(oneRttReadCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 0-RTT read cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<Aead> ServerHandshake::getZeroRttReadCipher() {
  if (!error_) {
    return std::move(zeroRttReadCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 1-RTT read header (packet number) cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<PacketNumberCipher>
ServerHandshake::getOneRttReadHeaderCipher() {
  if (!error_) {
    return std::move(oneRttReadHeaderCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 1-RTT write header (packet number) cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<PacketNumberCipher>
ServerHandshake::getOneRttWriteHeaderCipher() {
  if (!error_) {
    return std::move(oneRttWriteHeaderCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the handshake read header (packet number) cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<PacketNumberCipher>
ServerHandshake::getHandshakeReadHeaderCipher() {
  if (!error_) {
    return std::move(handshakeReadHeaderCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Hands the 0-RTT read header (packet number) cipher over to the caller.
 *
 * Ownership is transferred, so the stored cipher is null afterwards until a
 * new one is computed. The result is null if no cipher is currently stored.
 *
 * @throws QuicTransportException if a handshake error has been recorded.
 */
std::unique_ptr<PacketNumberCipher>
ServerHandshake::getZeroRttReadHeaderCipher() {
  if (!error_) {
    return std::move(zeroRttReadHeaderCipher_);
  }
  throw QuicTransportException(error_->first, error_->second);
}
/**
 * Cancels the handshake from the application's point of view.
 *
 * The application will not get any more callbacks from the handshake layer
 * after this method returns. This is done by clearing callback_: both
 * onWriteData() and processActions() check callback_ before touching the
 * crypto streams or notifying the application.
 */
void ServerHandshake::cancel() {
callback_ = nullptr;
}
/**
 * Returns the current handshake phase. The phase is advanced by the action
 * visitor: to KeysDerived on ReportEarlyHandshakeSuccess and to Established
 * on ReportHandshakeSuccess.
 */
ServerHandshake::Phase ServerHandshake::getPhase() const {
return phase_;
}
/**
 * Returns the transport parameters sent by the client, as recorded by the
 * server transport parameters extension passed to accept().
 *
 * @return The client's parameters, or folly::none when the extension has not
 *         recorded any. folly::none is also returned when no extension has
 *         been installed yet (accept() not called), instead of dereferencing
 *         a null pointer.
 */
folly::Optional<ClientTransportParameters>
ServerHandshake::getClientTransportParams() {
  if (!transportParams_) {
    return folly::none;
  }
  return transportParams_->getClientTransportParams();
}
/**
 * Returns true once a ReportHandshakeSuccess action has been processed (the
 * action visitor sets handshakeDone_ at that point).
 */
bool ServerHandshake::isHandshakeDone() {
return handshakeDone_;
}
/**
 * Returns a reference to the fizz server state held by this handshake. The
 * state is updated by MutateState actions as they are processed.
 */
const fizz::server::State& ServerHandshake::getState() const {
return state_;
}
/**
 * Returns the ALPN value stored in the fizz server state (state_.alpn()).
 * NOTE(review): presumably empty until ALPN has been negotiated; confirm
 * against fizz::server::State.
 */
const folly::Optional<std::string>& ServerHandshake::getApplicationProtocol()
const {
return state_.alpn();
}
void ServerHandshake::onError(
std::pair<std::string, TransportErrorCode> error) {
VLOG(10) << "ServerHandshake error " << error.first;
error_ = error;
handshakeEventAvailable_ = true;
}
/**
 * Copies handshake records produced by fizz onto the QUIC crypto stream of
 * the matching encryption level and flags that a handshake event is
 * available.
 *
 * Records whose content type is not `handshake` are skipped. A record at the
 * EarlyData level is a fatal error, whatever its content type.
 */
void ServerHandshake::onWriteData(fizz::WriteToSocket& write) {
  if (callback_ == nullptr) {
    // We've been canceled, just return. If we're canceled it's possible that
    // cryptoState_ has been deleted, so let's not refer to it.
    return;
  }
  for (auto& record : write.contents) {
    const auto level = getEncryptionLevelFromFizz(record.encryptionLevel);
    CHECK(level != EncryptionLevel::EarlyData)
        << "Server cannot write early data";
    if (record.contentType == fizz::ContentType::handshake) {
      writeDataToQuicStream(
          *getCryptoStream(cryptoState_, level), std::move(record.data));
    }
  }
  handshakeEventAvailable_ = true;
}
/**
 * Called by the action visitor when handshake success is reported outside the
 * Handshake phase. It only flags that a handshake event is available, so that
 * processActions() notifies the callback.
 */
void ServerHandshake::onHandshakeDone() {
handshakeEventAvailable_ = true;
}
/**
 * Accepts a new batch of (possibly asynchronous) fizz actions for processing.
 *
 * Only one batch may be in flight at a time. actionGuard_ marks a batch as
 * pending and keeps the connection alive until processActions() releases it.
 * If a batch is already pending, an INTERNAL_ERROR is recorded and the new
 * actions are dropped.
 */
void ServerHandshake::addProcessingActions(fizz::server::AsyncActions actions) {
  if (!actionGuard_) {
    actionGuard_ = folly::DelayedDestruction::DestructorGuard(conn_);
    startActions(std::move(actions));
    return;
  }
  onError(std::make_pair(
      "Processing action while pending", TransportErrorCode::INTERNAL_ERROR));
}
/**
 * Runs a batch of fizz actions through processActions().
 *
 * Actions that are already available are processed inline. Actions delivered
 * as a SemiFuture are processed on executor_ once the future completes.
 */
void ServerHandshake::startActions(fizz::server::AsyncActions actions) {
  folly::variant_match(
      actions,
      [this](fizz::server::Actions& readyActions) {
        processActions(std::move(readyActions));
      },
      [this](folly::SemiFuture<fizz::server::Actions>& pendingActions) {
        std::move(pendingActions)
            .via(executor_)
            .then(&ServerHandshake::processActions, this);
      });
}
/**
 * Drives the handshake forward for as long as there is work to do.
 *
 * Each iteration takes actionGuard_ and then either feeds the read buffer of
 * the current read encryption level to processSocketData() or, when the
 * state machine is waiting for data, tries processPendingCryptoEvent(). The
 * loop ends when a batch of actions is still pending (actionGuard_ held), an
 * error has been recorded, or there is no pending crypto event left.
 *
 * Nested calls return immediately; the outermost call keeps looping.
 */
void ServerHandshake::processPendingEvents() {
  if (inProcessPendingEvents_) {
    return;
  }
  // Keep the connection alive for the whole loop.
  folly::DelayedDestruction::DestructorGuard dg(conn_);
  inProcessPendingEvents_ = true;
  SCOPE_EXIT {
    inProcessPendingEvents_ = false;
  };
  for (;;) {
    if (actionGuard_ || error_) {
      break;
    }
    actionGuard_ = folly::DelayedDestruction::DestructorGuard(conn_);
    if (waitForData_) {
      if (processPendingCryptoEvent()) {
        continue;
      }
      // Nothing left to do: release the guard taken for this iteration.
      actionGuard_ = folly::DelayedDestruction::DestructorGuard(nullptr);
      return;
    }
    switch (getReadRecordLayerEncryptionLevel()) {
      case EncryptionLevel::Initial:
        processSocketData(initialReadBuf_);
        break;
      case EncryptionLevel::Handshake:
        processSocketData(handshakeReadBuf_);
        break;
      case EncryptionLevel::EarlyData:
      case EncryptionLevel::AppData:
        // TODO: Get rid of appDataReadBuf_ once we do not need EndOfEarlyData
        // any more.
        processSocketData(appDataReadBuf_);
        break;
      default:
        LOG(FATAL) << "Unhandled EncryptionLevel";
    }
  }
}
/**
 * Applies a single fizz server action to the owning ServerHandshake.
 *
 * processActions() calls the overload that matches the type of each action in
 * a completed batch.
 */
class ServerHandshake::ActionMoveVisitor : public boost::static_visitor<> {
 public:
  explicit ActionMoveVisitor(ServerHandshake& server) : server_(server) {}

  // Application data must never show up on the crypto stream.
  void operator()(fizz::DeliverAppData&) {
    server_.onError(std::make_pair(
        "Unexpected data on crypto stream",
        TransportErrorCode::PROTOCOL_VIOLATION));
  }

  // Handshake bytes produced by fizz go onto the QUIC crypto streams.
  void operator()(fizz::WriteToSocket& write) {
    server_.onWriteData(write);
  }

  void operator()(fizz::server::ReportEarlyHandshakeSuccess&) {
    server_.phase_ = Phase::KeysDerived;
  }

  void operator()(fizz::server::ReportHandshakeSuccess&) {
    server_.handshakeDone_ = true;
    auto originalPhase = server_.phase_;
    // Fizz only reports handshake success when the server receives the full
    // client finished. At this point we can write any post handshake data and
    // crypto data with the 1-rtt keys.
    server_.phase_ = Phase::Established;
    if (originalPhase != Phase::Handshake) {
      // We already derived the zero rtt keys as well as the one rtt write
      // keys.
      server_.onHandshakeDone();
    }
  }

  /**
   * Records a fizz error as a QUIC crypto error.
   *
   * The error code is CRYPTO_ERROR plus the TLS alert carried by the
   * FizzException. When the error carries no alert, `internal_error` is used
   * as the alert, and it gets the same CRYPTO_ERROR offset so the resulting
   * code stays inside the crypto error range.
   */
  void operator()(fizz::ReportError& err) {
    auto errMsg = err.error.what();
    if (errMsg.empty()) {
      errMsg = "Error during handshake";
    }
    auto fe = err.error.get_exception<fizz::FizzException>();
    const auto alert = (fe && fe->getAlert())
        ? fe->getAlert().value()
        : fizz::AlertDescription::internal_error;
    server_.onError(
        std::make_pair(errMsg.toStdString(), toCryptoError(alert)));
  }

  void operator()(fizz::WaitForData&) {
    server_.waitForData_ = true;
  }

  void operator()(fizz::server::MutateState& mutator) {
    mutator(server_.state_);
  }

  void operator()(fizz::server::AttemptVersionFallback&) {
    CHECK(false) << "Fallback Unexpected";
  }

  void operator()(fizz::EndOfData&) {
    server_.onError(std::make_pair(
        "Unexpected close notify received",
        TransportErrorCode::INTERNAL_ERROR));
  }

  /**
   * Builds the ciphers for a newly available secret. Secrets that are not
   * listed below (other early secrets, master secrets) are ignored; a
   * handshake event is flagged in every case.
   */
  void operator()(fizz::SecretAvailable& secretAvailable) {
    switch (secretAvailable.secret.type.type()) {
      case fizz::SecretType::Type::EarlySecrets_E:
        switch (*secretAvailable.secret.type.asEarlySecrets()) {
          case fizz::EarlySecrets::ClientEarlyTraffic:
            server_.computeCiphers(
                CipherKind::ZeroRttRead,
                folly::range(secretAvailable.secret.secret));
            break;
          default:
            break;
        }
        break;
      case fizz::SecretType::Type::HandshakeSecrets_E:
        switch (*secretAvailable.secret.type.asHandshakeSecrets()) {
          case fizz::HandshakeSecrets::ClientHandshakeTraffic:
            server_.computeCiphers(
                CipherKind::HandshakeRead,
                folly::range(secretAvailable.secret.secret));
            break;
          case fizz::HandshakeSecrets::ServerHandshakeTraffic:
            server_.computeCiphers(
                CipherKind::HandshakeWrite,
                folly::range(secretAvailable.secret.secret));
            break;
        }
        break;
      case fizz::SecretType::Type::AppTrafficSecrets_E:
        switch (*secretAvailable.secret.type.asAppTrafficSecrets()) {
          case fizz::AppTrafficSecrets::ClientAppTraffic:
            server_.computeCiphers(
                CipherKind::OneRttRead,
                folly::range(secretAvailable.secret.secret));
            break;
          case fizz::AppTrafficSecrets::ServerAppTraffic:
            server_.computeCiphers(
                CipherKind::OneRttWrite,
                folly::range(secretAvailable.secret.secret));
            break;
        }
        break;
      case fizz::SecretType::Type::MasterSecrets_E:
        break;
    }
    server_.handshakeEventAvailable_ = true;
  }

 private:
  // Maps a TLS alert to the transport error code CRYPTO_ERROR + alert.
  static TransportErrorCode toCryptoError(fizz::AlertDescription alert) {
    using ErrorCodeType = std::underlying_type<TransportErrorCode>::type;
    auto code = static_cast<ErrorCodeType>(alert);
    code += static_cast<ErrorCodeType>(TransportErrorCode::CRYPTO_ERROR);
    return static_cast<TransportErrorCode>(code);
  }

  ServerHandshake& server_;
};
/**
 * Applies a completed batch of fizz server actions to this handshake.
 *
 * Each action is dispatched to the matching ActionMoveVisitor overload. Once
 * the batch is applied, actionGuard_ is released, the callback is notified if
 * an event was flagged, and processPendingEvents() is called to continue with
 * any remaining work. The order of these final steps matters.
 */
void ServerHandshake::processActions(
fizz::server::ServerStateMachine::CompletedActions actions) {
// This extra DestructorGuard is needed due to the gap between clearing
// actionGuard_ and potentially processing another action.
folly::DelayedDestruction::DestructorGuard dg(conn_);
ActionMoveVisitor visitor(*this);
// Dispatch on the action's runtime type; every case forwards to the visitor
// overload for that type.
for (auto& action : actions) {
switch (action.type()) {
case fizz::server::Action::Type::DeliverAppData_E:
visitor(*action.asDeliverAppData());
break;
case fizz::server::Action::Type::WriteToSocket_E:
visitor(*action.asWriteToSocket());
break;
case fizz::server::Action::Type::ReportHandshakeSuccess_E:
visitor(*action.asReportHandshakeSuccess());
break;
case fizz::server::Action::Type::ReportEarlyHandshakeSuccess_E:
visitor(*action.asReportEarlyHandshakeSuccess());
break;
case fizz::server::Action::Type::ReportError_E:
visitor(*action.asReportError());
break;
case fizz::server::Action::Type::EndOfData_E:
visitor(*action.asEndOfData());
break;
case fizz::server::Action::Type::MutateState_E:
visitor(*action.asMutateState());
break;
case fizz::server::Action::Type::WaitForData_E:
visitor(*action.asWaitForData());
break;
case fizz::server::Action::Type::AttemptVersionFallback_E:
visitor(*action.asAttemptVersionFallback());
break;
case fizz::server::Action::Type::SecretAvailable_E:
visitor(*action.asSecretAvailable());
break;
}
}
// The batch is done: release the guard so that a new batch can be accepted
// by addProcessingActions() / started by processPendingEvents().
actionGuard_ = folly::DelayedDestruction::DestructorGuard(nullptr);
// Notify the application only if an event was flagged, the handshake has not
// been canceled, and we are not running underneath accept(), doHandshake() or
// writeNewSessionTicket() (those set inHandshakeStack_).
if (callback_ && !inHandshakeStack_ && handshakeEventAvailable_) {
callback_->onCryptoEventAvailable();
}
// The flag is reset whether or not the callback was invoked.
handshakeEventAvailable_ = false;
// Continue with buffered data / pending crypto events, if any.
processPendingEvents();
}
/**
 * Builds the AEAD and header (packet number) cipher for `secret` and stores
 * them in the slot selected by `kind`.
 *
 * Read ciphers and the 1-RTT write ciphers are kept on the handshake until
 * fetched through the corresponding getter. The handshake write ciphers are
 * installed directly on the connection state. A handshake event is flagged
 * afterwards.
 */
void ServerHandshake::computeCiphers(CipherKind kind, folly::ByteRange secret) {
  auto ciphers = buildCiphers(secret);
  auto& aead = std::get<0>(ciphers);
  auto& headerCipher = std::get<1>(ciphers);
  switch (kind) {
    case CipherKind::HandshakeRead:
      handshakeReadCipher_ = std::move(aead);
      handshakeReadHeaderCipher_ = std::move(headerCipher);
      break;
    case CipherKind::HandshakeWrite:
      conn_->handshakeWriteCipher = std::move(aead);
      conn_->handshakeWriteHeaderCipher = std::move(headerCipher);
      break;
    case CipherKind::OneRttRead:
      oneRttReadCipher_ = std::move(aead);
      oneRttReadHeaderCipher_ = std::move(headerCipher);
      break;
    case CipherKind::OneRttWrite:
      oneRttWriteCipher_ = std::move(aead);
      oneRttWriteHeaderCipher_ = std::move(headerCipher);
      break;
    case CipherKind::ZeroRttRead:
      zeroRttReadCipher_ = std::move(aead);
      zeroRttReadHeaderCipher_ = std::move(headerCipher);
      break;
    default:
      folly::assume_unreachable();
  }
  handshakeEventAvailable_ = true;
}
} // namespace quic