fizz/protocol/AsyncFizzBase.h (237 lines of code) (raw):

/* * Copyright (c) 2018-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <fizz/protocol/KeyScheduler.h> #include <fizz/record/Types.h> #include <folly/io/IOBufIovecBuilder.h> #include <folly/io/IOBufQueue.h> #include <folly/io/async/AsyncSocket.h> #include <folly/io/async/WriteChainAsyncTransportWrapper.h> namespace fizz { using Cert = folly::AsyncTransportCertificate; /** * This class is a wrapper around AsyncTransportWrapper to handle most app level * interactions so that the derived client and server classes */ class AsyncFizzBase : public folly::WriteChainAsyncTransportWrapper< folly::AsyncTransportWrapper>, protected folly::AsyncTransportWrapper::WriteCallback, protected folly::AsyncTransportWrapper::ReadCallback, protected folly::EventRecvmsgCallback { public: using UniquePtr = std::unique_ptr<AsyncFizzBase, folly::DelayedDestruction::Destructor>; using ReadCallback = folly::AsyncTransportWrapper::ReadCallback; class HandshakeTimeout : public folly::AsyncTimeout { public: HandshakeTimeout(AsyncFizzBase& transport, folly::EventBase* eventBase) : folly::AsyncTimeout(eventBase), transport_(transport) {} ~HandshakeTimeout() override = default; void timeoutExpired() noexcept override { transport_.handshakeTimeoutExpired(); } private: AsyncFizzBase& transport_; }; class SecretCallback { public: virtual ~SecretCallback() = default; /** * Each of the below is called when the corresponding secret is received. */ virtual void externalPskBinderAvailable( const std::vector<uint8_t>&) noexcept {} virtual void resumptionPskBinderAvailable( const std::vector<uint8_t>&) noexcept {} virtual void earlyExporterSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void clientEarlyTrafficSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void clientHandshakeTrafficSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void serverHandshakeTrafficSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void exporterMasterSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void resumptionMasterSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void clientAppTrafficSecretAvailable( const std::vector<uint8_t>&) noexcept {} virtual void serverAppTrafficSecretAvailable( const std::vector<uint8_t>&) noexcept {} }; class EndOfTLSCallback { public: virtual ~EndOfTLSCallback() = default; virtual void endOfTLS( AsyncFizzBase* transport, std::unique_ptr<folly::IOBuf> endOfData) = 0; }; struct TransportOptions { /** * Controls whether or not the async recv callback should be registered * (for io_uring) */ bool registerEventCallback{false}; /* * AsyncTransport read mode. * * This setting controls the strategy for reading data from the underlying * socket. * * ReadMode::ReadBuffer (default) * Under this mode, Fizz will allocate contiguous chunks of memory to * read incoming encrypted records. This might lead to higher mem usage * due to the way the memory is allocated from an IOBufQueue and also * due to the inability to do in place decryption for shared buffers * * ReadMode::ReadVec * Under this mode, Fizz will use vectored IO (`readv`) to read * incoming data. This can help avoid additional copies at the expense * of allocating extra ref counting objects. The overall mem usage is * also dependent on the readVecBlockSize value. * */ folly::AsyncReader::ReadCallback::ReadMode readMode{ReadMode::ReadBuffer}; /* * AsyncTransport read vec block size */ size_t readVecBlockSize{ folly::IOBufIovecBuilder::Options::kDefaultBlockSize}; }; explicit AsyncFizzBase( folly::AsyncTransportWrapper::UniquePtr transport, TransportOptions options); ~AsyncFizzBase() override; /** * App level information for reading/writing app data. */ ReadCallback* getReadCallback() const override; void setReadCB(ReadCallback* callback) override; void writeChain( folly::AsyncTransportWrapper::WriteCallback* callback, std::unique_ptr<folly::IOBuf>&& buf, folly::WriteFlags flags = folly::WriteFlags::NONE) override; /** * App data usage accounting. */ size_t getAppBytesWritten() const override; size_t getAppBytesReceived() const override; size_t getAppBytesBuffered() const override; /** * Information about the current transport state. * To be implemented by derived classes. */ bool good() const override = 0; bool readable() const override = 0; bool connecting() const override = 0; bool error() const override = 0; /** * Get the certificates in fizz::Cert form. */ const Cert* getPeerCertificate() const override = 0; const Cert* getSelfCertificate() const override = 0; bool isReplaySafe() const override = 0; void setReplaySafetyCallback( folly::AsyncTransport::ReplaySafetyCallback* callback) override = 0; std::string getApplicationProtocol() const noexcept override = 0; /** * Get the CipherSuite negotiated in this transport. */ virtual folly::Optional<CipherSuite> getCipher() const = 0; /** * Get the supported signature schemes in this transport. */ virtual std::vector<SignatureScheme> getSupportedSigSchemes() const = 0; /** * Get the exported material. */ Buf getExportedKeyingMaterial( folly::StringPiece label, Buf context, uint16_t length) const override = 0; /** * Clean up transport on destruction */ void destroy() override; /** * Identify the transport as Fizz. */ std::string getSecurityProtocol() const override { return "Fizz"; } /** * EventBase operations. */ void attachTimeoutManager(folly::TimeoutManager* manager) { handshakeTimeout_.attachTimeoutManager(manager); } void detachTimeoutManager() { handshakeTimeout_.detachTimeoutManager(); } void attachEventBase(folly::EventBase* eventBase) override { handshakeTimeout_.attachEventBase(eventBase); transport_->attachEventBase(eventBase); resumeEvents(); // we want to avoid setting a read cb on a bad transport (i.e. closed or // disconnected) unless we have a read callback we can pass the errors to. if (transport_->good() || readCallback_) { startTransportReads(); } } void detachEventBase() override { handshakeTimeout_.detachEventBase(); transport_->setEventCallback(nullptr); transport_->setReadCB(nullptr); transport_->detachEventBase(); pauseEvents(); } bool isDetachable() const override { return !handshakeTimeout_.isScheduled() && transport_->isDetachable(); } void setSecretCallback(SecretCallback* cb) { secretCallback_ = cb; } SecretCallback* getSecretCallback() { return secretCallback_; } // Note we clearly do not own the callback, and thus it is the caller's // responsibility to ensure the callback outlives the lifetime of // the fizz base instance. There are a couple key behavior differences if // this callback is set. // 1. We do not close the transport on receivng a close notify. It is your // responsibility to do whatever is appropriate. // 2. We do not call readEOF on any read callback set on the transport. // 3. Depending on when the tls connection is closed, there may be pending // data that exists past the close notify, this is passed along to the caller // in the endOfTLS method and the caller must decide what to do with the data virtual void setEndOfTLSCallback(EndOfTLSCallback* cb) { endOfTLSCallback_ = cb; } /** * setHandshakeRecordAlignedReads defines the behavior for reading data * from the backing transport during the handshake. * * This must be called prior to initiating the handshake. * * If true, this indicates that during the handshake, Fizz will read data * such that at the end of the handshake, the next byte in the underlying * transport's buffer (e.g. the kernel buffer) is guaranteed to be aligned * on a record boundary. * * In practice, this means that during the handshake, Fizz will read records * by (1) reading the record header and (2) reading just enough bytes to * complete the current record. This uses more system calls. * * If false, Fizz will read data from the underlying transport in chunks not * tied to any record boundary. */ void setHandshakeRecordAlignedReads(bool flag) { constexpr size_t kRecordHeaderSize = 5; if (flag) { readSizeHint_ = kRecordHeaderSize; } } /* * Gets the client random associated with this connection. The CR can be * used as a transport agnostic identifier (for instance, for NSS keylogging) */ virtual folly::Optional<Random> getClientRandom() const = 0; /* * Used to shut down the tls session, without shutting down the underlying * transport. Note you will still need to set setCloseTransportOnCloseNotify. */ virtual void tlsShutdown() = 0; /* * Sets whether or not to force in-place decryption of records. This is * usually safe to do, as long as the application can handle chained IOBufs * (as opposed to a contiguous buffer in the non-in-place case). */ void setDecryptInplace(bool inPlace) { readAeadOptions_.bufferOpt = inPlace ? Aead::BufferOption::AllowInPlace : Aead::BufferOption::RespectSharedPolicy; } /* * This sets whether or not to always perform encryption in-place, using the * same buffers passed in for writing to hold the encrypted data. The code * will do this opportunistically in certain cases (unique buffer passed in * and not split into records), but by setting this to true you can indicate * that the buffers passed in can always be used for in-place encryption * safely. This is not enabled by default (for safety). * * If you pass in unshared IOBufs for writing, you can set this to true. * Otherwise, if you have a shared buffer, its contents will be overwritten * (without throwing an error), affecting the other IOBufs sharing the * underlying buffer. Thus, in many cases it's not appropriate to set this to * true when passing in shared buffers, as the original plaintext in the * buffer will be lost. */ void setEncryptInplace(bool inPlace) { writeAeadOptions_.bufferOpt = inPlace ? Aead::BufferOption::AllowInPlace : Aead::BufferOption::RespectSharedPolicy; } protected: /** * Start reading raw data from the transport. */ virtual void startTransportReads(); /** * Interface for the derived class to schedule a handshake timeout. * * transportError() will be called if the timeout fires before it is * cancelled. */ virtual void startHandshakeTimeout(std::chrono::milliseconds); virtual void cancelHandshakeTimeout(); /** * Interfaces for the derived class to interact with the app level read * callback. */ virtual void deliverAppData(std::unique_ptr<folly::IOBuf> buf); virtual void deliverError( const folly::AsyncSocketException& ex, bool closeTransport = true); /** * Interface for the derived class to implement to receive app data from the * app layer. */ virtual void writeAppData( folly::AsyncTransportWrapper::WriteCallback* callback, std::unique_ptr<folly::IOBuf>&& buf, folly::WriteFlags flags = folly::WriteFlags::NONE) = 0; /** * Alert the derived class that a transport error occured. */ virtual void transportError(const folly::AsyncSocketException& ex) = 0; /** * Alert the derived class that additional data is available in * transportReadBuf_. */ virtual void transportDataAvailable() = 0; /** * Alert the derived class that new event processing should be paused/resumed. */ virtual void pauseEvents() = 0; virtual void resumeEvents() = 0; /** * Allows the derived class to give a derived secret to the secret callback. */ virtual void secretAvailable(const DerivedSecret& secret) noexcept; /** * Signal end of tls connection by a graceful shutdown. */ virtual void endOfTLS(std::unique_ptr<folly::IOBuf> endOfData) noexcept; /** * Called by derived classes to control the size of the next read from the * underlying transport (if using the readDataAvailable() API) when * the transport performs record aligned reads. * * Record aligned reads are not the default; it must be explicitly enabled * through AsyncFizzBase::setHandshakeRecordAlignedReads() * * setting hint=0 disables this functionality. All subsequent updateReadHint() * values will be ignored. */ void updateReadHint(size_t hint) { if (readSizeHint_ > 0) { readSizeHint_ = hint; } } folly::IOBufQueue transportReadBuf_{folly::IOBufQueue::cacheChainLength()}; Aead::AeadOptions readAeadOptions_; Aead::AeadOptions writeAeadOptions_; private: class QueuedWriteRequest : private folly::AsyncTransportWrapper::WriteCallback { public: QueuedWriteRequest( AsyncFizzBase* base, folly::AsyncTransportWrapper::WriteCallback* callback, std::unique_ptr<folly::IOBuf> data, folly::WriteFlags flags); void startWriting(); void append(QueuedWriteRequest* request); void unlinkFromBase(); size_t getEntireChainBytesBuffered() { DCHECK(!next_); return entireChainBytesBuffered; } private: void writeSuccess() noexcept override; void writeErr(size_t, const folly::AsyncSocketException&) noexcept override; QueuedWriteRequest* deliverSingleWriteErr( const folly::AsyncSocketException&); void advanceOnBase(); AsyncFizzBase* asyncFizzBase_; folly::AsyncTransportWrapper::WriteCallback* callback_; folly::IOBufQueue data_{folly::IOBufQueue::cacheChainLength()}; folly::WriteFlags flags_; size_t dataWritten_{0}; // Data length of the entire chain. Only valid at the tail node // of the chain, i.e. when next_ is null. size_t entireChainBytesBuffered; QueuedWriteRequest* next_{nullptr}; }; class FizzMsgHdr; /** * EventRecvmsgCallback implementation */ folly::EventRecvmsgCallback::MsgHdr* allocateData() override; void eventRecvmsgCallback(FizzMsgHdr* msgHdr, int res); /** * ReadCallback implementation. */ void getReadBuffer(void** bufReturn, size_t* lenReturn) override; void getReadBuffers(folly::IOBufIovecBuilder::IoVecVec& iovs) override; void readDataAvailable(size_t len) noexcept override; bool isBufferMovable() noexcept override; void readBufferAvailable( std::unique_ptr<folly::IOBuf> data) noexcept override; void readEOF() noexcept override; void readErr(const folly::AsyncSocketException& ex) noexcept override; /** * WriteCallback implementation, for use with handshake messages. */ void writeSuccess() noexcept override; void writeErr( size_t bytesWritten, const folly::AsyncSocketException& ex) noexcept override; void checkBufLen(); void handshakeTimeoutExpired() noexcept; ReadCallback* readCallback_{nullptr}; std::unique_ptr<folly::IOBuf> appDataBuf_; size_t appBytesWritten_{0}; size_t appBytesReceived_{0}; size_t readSizeHint_{0}; QueuedWriteRequest* tailWriteRequest_{nullptr}; HandshakeTimeout handshakeTimeout_; SecretCallback* secretCallback_{nullptr}; EndOfTLSCallback* endOfTLSCallback_{nullptr}; TransportOptions transportOptions_; std::unique_ptr<FizzMsgHdr> msgHdr_; folly::IOBufIovecBuilder ioVecQueue_; }; } // namespace fizz