fizz/protocol/AsyncFizzBase.cpp (465 lines of code) (raw):
/*
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <fizz/protocol/AsyncFizzBase.h>
#include <folly/Conv.h>
#include <folly/io/Cursor.h>
namespace fizz {
using folly::AsyncSocketException;
// Bounds on each read buffer carved out of transportReadBuf_ when the
// transport cannot hand us movable IOBufs (see getReadBuffer()).
static constexpr uint32_t kMinReadSize = 1460;
static constexpr uint32_t kMaxReadSize = 4000;

// Back-pressure limit: when no read callback is installed and buffered,
// undelivered data reaches this many bytes, reads from the underlying
// transport are paused (see checkBufLen()).
static constexpr uint32_t kMaxBufSize = 64 * 1024;

// Shared write buffers larger than this are split into chunks of at most this
// size, so the whole plaintext and whole ciphertext are never held in memory
// at the same time.
static constexpr uint32_t kPartialWriteThreshold = 128 * 1024;
/**
 * Wraps `transport` and configures read handling from `options`.
 *
 * `transport` is moved into the WriteChainAsyncTransportWrapper base. Base
 * classes are initialized before members, so `transport_` (presumably the
 * base's owned transport) is already usable when `handshakeTimeout_` reads its
 * EventBase below.
 *
 * NOTE(review): `ioVecQueue_` is built from `transportOptions_`. That is only
 * correct if `transportOptions_` is declared before `ioVecQueue_` in the header,
 * because members initialize in declaration order. Confirm in AsyncFizzBase.h.
 */
AsyncFizzBase::AsyncFizzBase(
folly::AsyncTransportWrapper::UniquePtr transport,
TransportOptions options)
: folly::WriteChainAsyncTransportWrapper<folly::AsyncTransportWrapper>(
std::move(transport)),
handshakeTimeout_(*this, transport_->getEventBase()),
transportOptions_(std::move(options)),
ioVecQueue_(folly::IOBufIovecBuilder::Options().setBlockSize(
transportOptions_.readVecBlockSize)) {
// Apply the read mode requested in the transport options.
setReadMode(transportOptions_.readMode);
}
/**
 * Clears this object's event and read callbacks on the underlying transport.
 * Also unlinks the tail queued write (if any), so it no longer points at this
 * base.
 */
AsyncFizzBase::~AsyncFizzBase() {
transport_->setEventCallback(nullptr);
transport_->setReadCB(nullptr);
// Only the tail request is tracked here. advanceOnBase() skips touching the
// base once the request has been unlinked (asyncFizzBase_ == nullptr).
if (tailWriteRequest_) {
tailWriteRequest_->unlinkFromBase();
}
}
/**
 * DelayedDestruction hook. Closes the transport immediately and clears our
 * callbacks on it, then defers to DelayedDestruction::destroy().
 */
void AsyncFizzBase::destroy() {
transport_->closeNow();
transport_->setEventCallback(nullptr);
transport_->setReadCB(nullptr);
DelayedDestruction::destroy();
}
/**
 * Returns the currently installed application read callback (may be null).
 */
AsyncFizzBase::ReadCallback* AsyncFizzBase::getReadCallback() const {
  return readCallback_;
}

/**
 * Installs `callback` as the application read callback; null clears it.
 *
 * With a non-null callback, any application data buffered while no callback
 * was installed is flushed to it first. Then either an error is delivered (if
 * the transport is no longer good) or transport reads are (re)started.
 */
void AsyncFizzBase::setReadCB(AsyncFizzBase::ReadCallback* callback) {
  readCallback_ = callback;
  if (readCallback_ == nullptr) {
    return;
  }

  if (appDataBuf_) {
    deliverAppData(nullptr);
  }

  if (good()) {
    // checkBufLen() may have unset our transport read callback while the
    // buffer was full; re-register it now that someone is consuming data.
    startTransportReads();
    return;
  }

  AsyncSocketException ex(
      AsyncSocketException::NOT_OPEN,
      "setReadCB() called with transport in bad state");
  deliverError(ex);
}
/**
 * Creates a queued write of `data` on behalf of `base`.
 *
 * `callback` may be null. When set, it is notified once all of `data` has been
 * written (writeSuccess) or the write fails (writeErr).
 * `entireChainBytesBuffered` starts as this request's own byte count; append()
 * later hands the tally to the next tail.
 */
AsyncFizzBase::QueuedWriteRequest::QueuedWriteRequest(
AsyncFizzBase* base,
folly::AsyncTransportWrapper::WriteCallback* callback,
std::unique_ptr<folly::IOBuf> data,
folly::WriteFlags flags)
: asyncFizzBase_(base), callback_(callback), flags_(flags) {
data_.append(std::move(data));
entireChainBytesBuffered = data_.chainLength();
}
void AsyncFizzBase::QueuedWriteRequest::startWriting() {
auto buf = data_.splitAtMost(kPartialWriteThreshold);
auto flags = flags_;
if (!data_.empty()) {
flags |= folly::WriteFlags::CORK;
}
size_t len = buf->computeChainDataLength();
dataWritten_ += len;
CHECK(asyncFizzBase_);
CHECK(asyncFizzBase_->tailWriteRequest_);
asyncFizzBase_->tailWriteRequest_->entireChainBytesBuffered -= len;
asyncFizzBase_->writeAppData(this, std::move(buf), flags);
}
/**
 * Links `request` behind this one. The aggregate buffered-byte count moves to
 * the new tail, so only the tail carries it.
 */
void AsyncFizzBase::QueuedWriteRequest::append(QueuedWriteRequest* request) {
  DCHECK(!next_);
  next_ = request;
  request->entireChainBytesBuffered += entireChainBytesBuffered;
  entireChainBytesBuffered = 0;
}

/**
 * Forgets the owning base (called when the base is destroyed first).
 */
void AsyncFizzBase::QueuedWriteRequest::unlinkFromBase() {
  asyncFizzBase_ = nullptr;
}
/**
 * Transport completion hook for one chunk of this request.
 *
 * While more of this request's data is pending, the next chunk is written.
 * Once it is exhausted, the request detaches from the base and deletes itself.
 * It then notifies the user's callback and starts the next queued request (if
 * any).
 */
void AsyncFizzBase::QueuedWriteRequest::writeSuccess() noexcept {
if (!data_.empty()) {
startWriting();
} else {
advanceOnBase();
// Copy out everything still needed before `this` is deleted.
auto callback = callback_;
auto next = next_;
auto base = asyncFizzBase_;
delete this;
// Hold the base (possibly null if it was unlinked) across the user
// callback. Presumably this is because the callback may trigger its
// destruction before next->startWriting() runs.
DelayedDestruction::DestructorGuard dg(base);
if (callback) {
callback->writeSuccess();
}
if (next) {
next->startWriting();
}
}
}
/**
 * Transport failure hook: fails this request and every request queued
 * behind it with `ex`.
 *
 * Each deliverSingleWriteErr() call destroys its request and returns the next
 * one. Walking the chain iteratively, rather than recursing through writeErr,
 * keeps stack usage constant however many writes are queued, avoiding
 * excessive stack growth.
 */
void AsyncFizzBase::QueuedWriteRequest::writeErr(
    size_t /* written */,
    const folly::AsyncSocketException& ex) noexcept {
  for (QueuedWriteRequest* pending = this; pending != nullptr;
       pending = pending->deliverSingleWriteErr(ex)) {
  }
}
/**
 * Fails this single request. It detaches from the base, deletes itself, and
 * reports `ex` to its callback (if any) together with dataWritten_, the bytes
 * this request had passed to writeAppData().
 *
 * @return the next queued request (possibly null), which the caller must fail
 *         in turn.
 */
AsyncFizzBase::QueuedWriteRequest*
AsyncFizzBase::QueuedWriteRequest::deliverSingleWriteErr(
const folly::AsyncSocketException& ex) {
advanceOnBase();
// Copy out everything still needed before `this` is deleted.
auto callback = callback_;
auto next = next_;
auto dataWritten = dataWritten_;
delete this;
if (callback) {
callback->writeErr(dataWritten, ex);
}
return next;
}
/**
 * If this request is the queue tail, clears the base's tailWriteRequest_.
 *
 * Requests with a successor are not the tail. Requests already unlinked from a
 * destroyed base have nothing to clear.
 */
void AsyncFizzBase::QueuedWriteRequest::advanceOnBase() {
  if (next_ || !asyncFizzBase_) {
    return;
  }
  CHECK_EQ(asyncFizzBase_->tailWriteRequest_, this);
  asyncFizzBase_->tailWriteRequest_ = nullptr;
}
/**
 * Writes application data `buf`, queueing it behind earlier writes when
 * needed.
 *
 * The write goes straight to writeAppData() unless either condition holds:
 *   - a queued write is already pending (ordering must be preserved), or
 *   - the write should be split.
 *
 * A write should be split when all of the following are true:
 *   - it is large, or the transport is already buffering raw bytes;
 *   - the buffer is shared, so it cannot be encrypted in place and both copies
 *     would otherwise live in memory at once;
 *   - we are not still connecting and are replay safe, so early data is never
 *     split across the early/normal data boundary.
 */
void AsyncFizzBase::writeChain(
    folly::AsyncTransportWrapper::WriteCallback* callback,
    std::unique_ptr<folly::IOBuf>&& buf,
    folly::WriteFlags flags) {
  const auto writeSize = buf->computeChainDataLength();
  appBytesWritten_ += writeSize;

  const bool largeWrite = writeSize > kPartialWriteThreshold;
  const bool transportBuffering = transport_->getRawBytesBuffered() > 0;
  const bool needsToQueue = (largeWrite || transportBuffering) &&
      buf->isShared() && !connecting() && isReplaySafe();

  if (!tailWriteRequest_ && !needsToQueue) {
    writeAppData(callback, std::move(buf), flags);
    return;
  }

  auto* request = new QueuedWriteRequest(this, callback, std::move(buf), flags);
  if (tailWriteRequest_) {
    // Chain behind the current tail; its writeSuccess() will start us.
    tailWriteRequest_->append(request);
    tailWriteRequest_ = request;
    return;
  }

  // Empty queue: become the tail and begin writing immediately.
  tailWriteRequest_ = request;
  request->startWriting();
}
/** Total application bytes passed to writeChain(). */
size_t AsyncFizzBase::getAppBytesWritten() const {
  return appBytesWritten_;
}

/** Total application bytes passed to deliverAppData(). */
size_t AsyncFizzBase::getAppBytesReceived() const {
  return appBytesReceived_;
}

/**
 * Bytes buffered by the underlying transport plus bytes still sitting in the
 * write queue (tracked on the queue's tail request).
 */
size_t AsyncFizzBase::getAppBytesBuffered() const {
  size_t queuedBytes = 0;
  if (tailWriteRequest_) {
    queuedBytes = tailWriteRequest_->getEntireChainBytesBuffered();
  }
  return transport_->getAppBytesBuffered() + queuedBytes;
}
/**
 * Registers this object as the transport's read callback. If the transport
 * options ask for it, this object is first registered as the event callback.
 */
void AsyncFizzBase::startTransportReads() {
  const bool wantEventCallback = transportOptions_.registerEventCallback;
  if (wantEventCallback) {
    transport_->setEventCallback(this);
  }
  transport_->setReadCB(this);
}

/** Arms the handshake timer for `timeout`. */
void AsyncFizzBase::startHandshakeTimeout(std::chrono::milliseconds timeout) {
  handshakeTimeout_.scheduleTimeout(timeout);
}

/** Disarms the handshake timer. */
void AsyncFizzBase::cancelHandshakeTimeout() {
  handshakeTimeout_.cancelTimeout();
}
/**
 * Delivers decrypted application data to the read callback.
 *
 * `data` may be null, e.g. when setReadCB() just wants to flush appDataBuf_.
 * Previously buffered data (appDataBuf_) is delivered first, followed by
 * `data`. Whatever cannot be delivered, because there is no callback or the
 * callback changed mid-delivery, is stored back in appDataBuf_. checkBufLen()
 * then decides whether to pause transport reads.
 */
void AsyncFizzBase::deliverAppData(std::unique_ptr<folly::IOBuf> data) {
if (data) {
appBytesReceived_ += data->computeChainDataLength();
}
// Append the new data after anything already buffered, preserving order.
if (appDataBuf_) {
if (data) {
appDataBuf_->prependChain(std::move(data));
}
data = std::move(appDataBuf_);
}
// Loop because a callback may swap itself for a different (movable or
// non-movable) callback while we are delivering.
while (readCallback_ && data) {
if (readCallback_->isBufferMovable()) {
// Zero-copy path: hand the whole chain over. Note this returns early, so
// checkBufLen() below is not reached on this path.
return readCallback_->readBufferAvailable(std::move(data));
} else {
// Copy path: fill caller-provided buffers from a cursor over `data`.
folly::io::Cursor cursor(data.get());
size_t available = 0;
while ((available = cursor.totalLength()) != 0 && readCallback_ &&
!readCallback_->isBufferMovable()) {
void* buf = nullptr;
size_t buflen = 0;
// Any exception from the user's getReadBuffer() becomes a BAD_ARGS
// (or the original AsyncSocketException) error.
// NOTE(review): on these error returns the undelivered bytes in `data`
// are discarded and checkBufLen() is skipped.
try {
readCallback_->getReadBuffer(&buf, &buflen);
} catch (const AsyncSocketException& ase) {
return deliverError(ase);
} catch (const std::exception& e) {
AsyncSocketException ase(
AsyncSocketException::BAD_ARGS,
folly::to<std::string>("getReadBuffer() threw ", e.what()));
return deliverError(ase);
} catch (...) {
AsyncSocketException ase(
AsyncSocketException::BAD_ARGS,
"getReadBuffer() threw unknown exception");
return deliverError(ase);
}
if (buflen == 0 || buf == nullptr) {
AsyncSocketException ase(
AsyncSocketException::BAD_ARGS,
"getReadBuffer() returned empty buffer");
return deliverError(ase);
}
size_t bytesToRead = std::min(buflen, available);
cursor.pull(buf, bytesToRead);
readCallback_->readDataAvailable(bytesToRead);
}
// If we have data left, it means the read callback changed and we need
// to save the remaining data (if any)
if (available != 0) {
std::unique_ptr<folly::IOBuf> remainingData;
cursor.clone(remainingData, available);
data = std::move(remainingData);
} else {
// Out of data. Reset the data pointer to end the loop
data.reset();
}
}
}
// No callback (or it was removed mid-delivery): keep the rest for later.
if (data) {
appDataBuf_ = std::move(data);
}
checkBufLen();
}
void AsyncFizzBase::deliverError(
const AsyncSocketException& ex,
bool closeTransport) {
DelayedDestruction::DestructorGuard dg(this);
if (readCallback_) {
auto readCallback = readCallback_;
readCallback_ = nullptr;
if (ex.getType() == AsyncSocketException::END_OF_FILE) {
readCallback->readEOF();
} else {
readCallback->readErr(ex);
}
}
// Clear the secret callback too.
if (secretCallback_) {
secretCallback_ = nullptr;
}
if (closeTransport) {
transport_->close();
}
}
/**
 * recvmsg header handed to folly's EventRecvmsgCallback machinery.
 *
 * Each instance is bound, via `arg_`, to the AsyncFizzBase that created it. It
 * routes completion and release through the static `cb` and `free` hooks
 * installed in the constructor.
 */
class AsyncFizzBase::FizzMsgHdr : public folly::EventRecvmsgCallback::MsgHdr {
// An owning AsyncFizzBase is always required.
FizzMsgHdr() = delete;
public:
~FizzMsgHdr() override = default;
explicit FizzMsgHdr(AsyncFizzBase* fizzBase) {
arg_ = fizzBase;
freeFunc_ = FizzMsgHdr::free;
cbFunc_ = FizzMsgHdr::cb;
}
// Re-initializes the msghdr with a single iovec that points at fresh space
// obtained from the owner's getReadBuffer().
void reset() {
data_ = msghdr{};
auto base = static_cast<AsyncFizzBase*>(arg_);
base->getReadBuffer(&iov_.iov_base, &iov_.iov_len);
data_.msg_iov = &iov_;
data_.msg_iovlen = 1;
}
// Release hook. Deleting through the base pointer is fine because the base
// destructor is virtual (the destructor above is marked `override`).
static void free(folly::EventRecvmsgCallback::MsgHdr* msgHdr) {
delete msgHdr;
}
// Completion hook: forwards the recvmsg result `res` to the owning base.
static void cb(folly::EventRecvmsgCallback::MsgHdr* msgHdr, int res) {
static_cast<AsyncFizzBase*>(msgHdr->arg_)
->eventRecvmsgCallback(static_cast<FizzMsgHdr*>(msgHdr), res);
}
private:
// Single scatter/gather entry; it is only meaningful after reset() fills it.
iovec iov_;
};
/**
 * Supplies a recvmsg header for the next event-based read.
 *
 * The header cached in msgHdr_ by the last completion is reused when present;
 * otherwise a new one bound to this object is created. Either way it is reset
 * to point at fresh read space before being returned.
 */
folly::EventRecvmsgCallback::MsgHdr* AsyncFizzBase::allocateData() {
  auto* hdr = msgHdr_.release();
  if (hdr == nullptr) {
    hdr = new FizzMsgHdr(this);
  }
  hdr->reset();
  return hdr;
}
/**
 * Completion handler for an event-based recvmsg issued with `msgHdr`.
 *
 * What happens depends on `res`:
 *   - res > 0: that many bytes were read into the preallocated space and are
 *     processed.
 *   - res == 0: EOF.
 *   - res < 0: an error; the negated value is passed as the exception's errno.
 *
 * The header is then cached in msgHdr_ for reuse.
 */
void AsyncFizzBase::eventRecvmsgCallback(FizzMsgHdr* msgHdr, int res) {
  DelayedDestruction::DestructorGuard dg(this);
  if (res < 0) {
    AsyncSocketException ex(
        AsyncSocketException::INTERNAL_ERROR, "event recv failed", -res);
    deliverError(ex);
  } else if (res == 0) {
    readEOF();
  } else {
    transportReadBuf_.postallocate(res);
    transportDataAvailable();
    checkBufLen();
  }
  msgHdr_.reset(msgHdr);
}
/**
 * Transport ReadCallback hook: provides space in transportReadBuf_ for the
 * next non-movable read.
 *
 * @param bufReturn receives the start of the preallocated space.
 * @param lenReturn receives how many bytes the transport may write there. It is
 *        never larger than the space preallocated and, when readSizeHint_ is
 *        nonzero, never larger than readSizeHint_.
 */
void AsyncFizzBase::getReadBuffer(void** bufReturn, size_t* lenReturn) {
  std::pair<void*, uint32_t> readSpace =
      transportReadBuf_.preallocate(kMinReadSize, kMaxReadSize);
  *bufReturn = readSpace.first;
  // `readSizeHint_`, if zero, indicates that we do not care about how much
  // data we read from the underlying socket.
  //
  // `readSizeHint_`, if nonzero, indicates the maximum amount of data we
  // want to read from the underlying socket. This is necessary for kTLS,
  // where we want to ensure that when ReportHandshakeSuccess is called, we
  // are at a known point in the TCP stream, so we can let the kernel start
  // decrypting records for us.
  //
  // For transport with "record aligned reads", we initially set `readSizeHint_`
  // equal to the size of the TLS record header. Subsequently, the state machine
  // will tell us exactly how much data is required to complete the record
  // in WaitForData actions.
  if (readSizeHint_ > 0) {
    // Clamp the hint to the space actually preallocated. Clamping to
    // kMinReadSize instead would needlessly cap every hinted read at 1460
    // bytes, even when more space was available and the hint asked for more.
    *lenReturn = std::min(
        static_cast<decltype(readSizeHint_)>(readSpace.second), readSizeHint_);
  } else {
    *lenReturn = readSpace.second;
  }
}
/**
 * Transport ReadVec hook: fills `iovs` from ioVecQueue_, requesting
 * kMaxReadSize bytes of space (exact semantics per
 * IOBufIovecBuilder::allocateBuffers).
 */
void AsyncFizzBase::getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) {
  ioVecQueue_.allocateBuffers(iovs, kMaxReadSize);
}

/**
 * Transport hook: `len` bytes were written into the space handed out by
 * getReadBuffers() (ReadVec mode) or getReadBuffer() (other modes). The bytes
 * are committed to transportReadBuf_ and then processed.
 */
void AsyncFizzBase::readDataAvailable(size_t len) noexcept {
  DelayedDestruction::DestructorGuard dg(this);
  const bool vectoredRead =
      getReadMode() == folly::AsyncTransport::ReadCallback::ReadMode::ReadVec;
  if (vectoredRead) {
    transportReadBuf_.append(ioVecQueue_.extractIOBufChain(len));
  } else {
    transportReadBuf_.postallocate(len);
  }
  transportDataAvailable();
  checkBufLen();
}
bool AsyncFizzBase::isBufferMovable() noexcept {
return true;
}
void AsyncFizzBase::readBufferAvailable(
std::unique_ptr<folly::IOBuf> data) noexcept {
DelayedDestruction::DestructorGuard dg(this);
transportReadBuf_.append(std::move(data));
transportDataAvailable();
checkBufLen();
}
void AsyncFizzBase::readEOF() noexcept {
AsyncSocketException eof(AsyncSocketException::END_OF_FILE, "readEOF()");
transportError(eof);
}
void AsyncFizzBase::readErr(const folly::AsyncSocketException& ex) noexcept {
transportError(ex);
}
void AsyncFizzBase::writeSuccess() noexcept {}
void AsyncFizzBase::writeErr(
size_t /* bytesWritten */,
const folly::AsyncSocketException& ex) noexcept {
transportError(ex);
}
/**
 * Applies back pressure. When no application read callback is installed and
 * either the undecrypted transport buffer or the undelivered app-data buffer
 * has reached kMaxBufSize, stop reading from the transport. setReadCB()
 * re-enables reads via startTransportReads().
 */
void AsyncFizzBase::checkBufLen() {
  if (readCallback_) {
    return;
  }
  bool overLimit = transportReadBuf_.chainLength() >= kMaxBufSize;
  if (!overLimit && appDataBuf_) {
    overLimit = appDataBuf_->computeChainDataLength() >= kMaxBufSize;
  }
  if (overLimit) {
    transport_->setEventCallback(nullptr);
    transport_->setReadCB(nullptr);
  }
}
void AsyncFizzBase::handshakeTimeoutExpired() noexcept {
AsyncSocketException eof(
AsyncSocketException::TIMED_OUT, "handshake timeout expired");
transportError(eof);
}
/**
 * Handles the end of the TLS session. `endOfData` holds any bytes that follow
 * it.
 *
 * If this happens while still connecting, it is treated as an INVALID_STATE
 * error. Otherwise an installed end-of-TLS callback takes over, receiving
 * `endOfData`. Without one, the read callback gets readEOF() and the transport
 * is closed.
 */
void AsyncFizzBase::endOfTLS(std::unique_ptr<folly::IOBuf> endOfData) noexcept {
  DelayedDestruction::DestructorGuard dg(this);

  if (connecting()) {
    AsyncSocketException ex(
        AsyncSocketException::INVALID_STATE,
        "tls connection torn down while connecting");
    transportError(ex);
    return;
  }

  if (endOfTLSCallback_) {
    endOfTLSCallback_->endOfTLS(this, std::move(endOfData));
    return;
  }

  // Default behaviour. Read callbacks commonly close on EOF, so when an
  // end-of-TLS callback is set we leave closing decisions to its owner; here
  // there is none, so signal EOF and close the transport ourselves.
  if (auto* readCallback = readCallback_) {
    readCallback_ = nullptr;
    readCallback->readEOF();
  }
  transport_->close();
}
// The visitor below dispatches each derived secret type to the matching SecretCallback method.
namespace {
/**
 * Dispatches a secret, identified by its SecretType, to the matching
 * `*Available()` method of an AsyncFizzBase::SecretCallback.
 *
 * Holds non-owning references: `cb` and `secretBuf` must outlive the visitor.
 */
class SecretVisitor {
public:
explicit SecretVisitor(
AsyncFizzBase::SecretCallback* cb,
const std::vector<uint8_t>& secretBuf)
: callback_(cb), secretBuf_(secretBuf) {}
// Unwraps the SecretType variant and forwards to the typed overload below.
void operator()(const SecretType& secretType) {
switch (secretType.type()) {
case SecretType::Type::EarlySecrets_E:
operator()(*secretType.asEarlySecrets());
break;
case SecretType::Type::HandshakeSecrets_E:
operator()(*secretType.asHandshakeSecrets());
break;
case SecretType::Type::MasterSecrets_E:
operator()(*secretType.asMasterSecrets());
break;
case SecretType::Type::AppTrafficSecrets_E:
operator()(*secretType.asAppTrafficSecrets());
break;
}
}
// Early-phase secrets: PSK binders, client early traffic, early exporter.
void operator()(const EarlySecrets& secret) {
switch (secret) {
case EarlySecrets::ExternalPskBinder:
callback_->externalPskBinderAvailable(secretBuf_);
return;
case EarlySecrets::ResumptionPskBinder:
callback_->resumptionPskBinderAvailable(secretBuf_);
return;
case EarlySecrets::ClientEarlyTraffic:
callback_->clientEarlyTrafficSecretAvailable(secretBuf_);
return;
case EarlySecrets::EarlyExporter:
callback_->earlyExporterSecretAvailable(secretBuf_);
return;
}
}
// Handshake traffic secrets for each direction.
void operator()(const HandshakeSecrets& secret) {
switch (secret) {
case HandshakeSecrets::ClientHandshakeTraffic:
callback_->clientHandshakeTrafficSecretAvailable(secretBuf_);
return;
case HandshakeSecrets::ServerHandshakeTraffic:
callback_->serverHandshakeTrafficSecretAvailable(secretBuf_);
return;
}
}
// Master-derived exporter and resumption secrets.
void operator()(const MasterSecrets& secret) {
switch (secret) {
case MasterSecrets::ExporterMaster:
callback_->exporterMasterSecretAvailable(secretBuf_);
return;
case MasterSecrets::ResumptionMaster:
callback_->resumptionMasterSecretAvailable(secretBuf_);
return;
}
}
// Application traffic secrets for each direction.
void operator()(const AppTrafficSecrets& secret) {
switch (secret) {
case AppTrafficSecrets::ClientAppTraffic:
callback_->clientAppTrafficSecretAvailable(secretBuf_);
return;
case AppTrafficSecrets::ServerAppTraffic:
callback_->serverAppTrafficSecretAvailable(secretBuf_);
return;
}
}
private:
// Non-owning; must be non-null when a secret is dispatched.
AsyncFizzBase::SecretCallback* callback_;
// Non-owning reference to the raw secret bytes.
const std::vector<uint8_t>& secretBuf_;
};
} // namespace
void AsyncFizzBase::secretAvailable(const DerivedSecret& secret) noexcept {
if (secretCallback_) {
SecretVisitor visitor(secretCallback_, secret.secret);
visitor(secret.type);
}
}
} // namespace fizz