util/WdtSocket.h (108 lines of code) (raw):
#pragma once
#include <sys/socket.h>
#include <wdt/ErrorCodes.h>
#include <wdt/Protocol.h>
#include <wdt/util/CommonImpl.h>
#include <wdt/util/EncryptionUtils.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
namespace facebook {
namespace wdt {
/// Callback type: a callable that takes no arguments and returns nothing.
/// WdtSocket uses it for its tag verification success callback.
using Func = std::function<void()>;
/// Base socket class.
/// Do not read/write more than 2GB at a time (the byte counts are ints).
/// This is ok because you don't want internal blocks bigger than a few
/// Mbytes anyway.
///
/// This header declares the public read/write/close API, accessors for the
/// fd, port, error codes and byte counters, and the private helpers used for
/// encryption, tag handling and iv changes. Apart from a few small inline
/// accessors, the member functions are only declared here; their definitions
/// live elsewhere.
class WdtSocket {
public:
/// The parameters mirror the members threadCtx_, port_, encryptionParams_,
/// ivChangeInterval_ and tagVerificationSuccessCallback_.
/// NOTE(review): threadCtx_ is a reference member, so presumably threadCtx
/// must outlive this object - confirm against the constructor definition.
WdtSocket(ThreadCtx &threadCtx, int port,
const EncryptionParams &encryptionParams, int64_t ivChangeInterval,
Func &&tagVerificationSuccessCallback);
// making the object non-copyable and non-movable
WdtSocket(const WdtSocket &stats) = delete;
WdtSocket &operator=(const WdtSocket &stats) = delete;
WdtSocket(WdtSocket &&stats) = delete;
WdtSocket &operator=(WdtSocket &&stats) = delete;
/// tries to read nbyte data and periodically checks for abort
/// NOTE(review): the return value is presumably the number of bytes read,
/// negative on failure, and tryFull presumably has the meaning documented on
/// ioWithAbortCheck - confirm in the definition.
int read(char *buf, int nbyte, bool tryFull = true);
/// tries to read nbyte data with a specific timeout (timeoutMs, in
/// milliseconds) and periodically checks for abort
int readWithTimeout(char *buf, int nbyte, int timeoutMs, bool tryFull = true);
/// tries to write nbyte data and periodically checks for abort, if retry is
/// true, socket tries to write as long as it makes some progress within a
/// write timeout
int write(char *buf, int nbyte, bool retry = false);
/// writes the tag/mac (for gcm) and shuts down the write half of the
/// underlying socket
ErrorCode shutdownWrites();
/// expect logical and physical end of stream: read the tag and finalize
ErrorCode expectEndOfStream();
/**
* Normal closing of the current connection.
* may return ENCRYPTION_ERROR if the stream is corrupt (gcm mode)
*/
ErrorCode closeConnection();
/**
* Close unexpectedly (will not read/write the checksum).
* This api is to be avoided/eliminated.
*/
void closeNoCheck();
/// @return current fd
int getFd() const;
/// sets the current fd
void setFd(int fd);
/// @return port
int getPort() const;
/// sets the port
void setPort(int port);
/// @return current encryption type
EncryptionType getEncryptionType() const;
/// @return possible non-retryable error
ErrorCode getNonRetryableErrCode() const;
/// @return read error code
ErrorCode getReadErrCode() const;
/// @return write error code
ErrorCode getWriteErrCode() const;
/// @return tcp receive buffer size
int getReceiveBufferSize() const;
/// @return tcp send buffer size
int getSendBufferSize() const;
/// @return number of unacked bytes in send buffer, returns -1 in case it
/// fails to get unacked bytes for this socket
int getUnackedBytes() const;
/// @return current value of totalRead_ (presumably the running count of
/// bytes read; it is updated outside this header)
int64_t getNumRead() const {
return totalRead_;
}
/// @return current value of totalWritten_ (presumably the running count of
/// bytes written; it is updated outside this header)
int64_t getNumWritten() const {
return totalWritten_;
}
/// Logs at INFO level and sets ivChangeInterval_ to 0, which, per the log
/// message, disables the periodic encryption iv change.
void disableIvChange() {
WLOG(INFO) << "Disabling periodic encryption iv change";
ivChangeInterval_ = 0;
}
/// sets read and write timeouts for the socket
void setSocketTimeouts();
// manipulates DSCP Bits
void setDscp(int dscp);
/// @return the thread context reference held by this socket
ThreadCtx& getThreadCtx() {
return threadCtx_;
}
/// Sets supportUnencryptedPeer_ to true; see the comment on that member for
/// what the flag means.
void enableUnencryptedPeerSupport() {
supportUnencryptedPeer_ = true;
}
/**
* Returns ip and port for a socket address
*
* @param sa socket address
* @param salen socket address length
* @param host this is set to host name
* @param port this is set to port
*
* @return whether getnameinfo was successful or not
*/
static bool getNameInfo(const struct sockaddr *sa, socklen_t salen,
std::string &host, std::string &port);
/// Virtual destructor: the class also declares a virtual member
/// (closeConnectionInternal), so it can serve as a polymorphic base.
virtual ~WdtSocket();
private:
// NOTE(review): presumably (re)initialize encryptor_ / decryptor_; the
// definitions are not in this header - confirm there.
void resetEncryptor();
void resetDecryptor();
/// computes effective timeout depending on the network timeout and abort
/// check interval
int getEffectiveTimeout(int networkTimeout);
/// @see ioWithAbortCheck
int64_t readWithAbortCheck(char *buf, int64_t nbyte, int timeoutMs,
bool tryFull);
/// @see ioWithAbortCheck
int64_t writeWithAbortCheck(const char *buf, int64_t nbyte, int timeoutMs,
bool tryFull);
/**
* Tries to read/write numBytes amount of data from fd. Also, checks for abort
* after every read/write call. Also, retries till the input timeout.
* Optionally, returns after first successful read/write call.
*
* The file descriptor and the abort checker are not parameters of this
* function. NOTE(review): presumably they come from fd_ and threadCtx_ -
* confirm in the definition.
*
* @param readOrWrite read/write
* @param tbuf buffer
* @param numBytes number of bytes to read/write
* @param timeoutMs timeout in milliseconds
* @param tryFull if true, this function tries to read complete data.
* Otherwise, this function returns after the first
* successful read/write. This is set to false for
* receiver pipelining.
*
* @return in case of success number of bytes read/written, else
* returns -1
*/
template <typename F, typename T>
int64_t ioWithAbortCheck(F readOrWrite, T tbuf, int64_t numBytes,
int timeoutMs, bool tryFull);
// computes next tag offset
int computeNextTagOffset(int64_t totalProcessed, int64_t tagInterval);
// reads from socket and decrypts. Does not understand tag verification
int readAndDecrypt(char *buf, int nbyte, int timeoutMs, bool tryFull);
// reads from socket, decrypts and verifies tag. If the read contains a tag,
// first, we read till the tag and decrypt. Then, the tag (plain-text) is read
// and verified. After that remaining bytes are read.
// This method expects at most one tag contained in the read. So, nbyte must
// be less than readTagInterval_
int readAndDecryptWithTag(char *buf, int nbyte, int timeoutMs, bool tryFull);
// reads encryption tag. Returns empty string in case of failure.
std::string readEncryptionTag();
// checks whether decryption iv has changed or not. If yes, it reads the new
// iv
bool checkAndChangeDecryptionIv(const std::string &tag);
// reads from socket. Does not understand encryption
int readInternal(char *buf, int nbyte, int timeoutMs, bool tryFull);
// encrypts and writes. Does not understand encryption tag
int encryptAndWrite(char *buf, int nbyte, int timeoutMs, bool retry);
// encrypts, writes and also adds tag if necessary.
// This method expects at most one tag contained in the write. So, nbyte must
// be less than writeTagInterval_
int encryptAndWriteWithTag(char *buf, int nbyte, int timeoutMs, bool retry);
// writes encryption tag. Returns status
bool writeEncryptionTag();
// checks whether encryption iv needs to change, and if yes, changes it. This
// also sends the new iv
bool checkAndChangeEncryptionIv();
// writes to socket. Does not understand encryption
int writeInternal(const char *buf, int nbyte, int timeoutMs, bool retry);
// NOTE(review): going by the names and the encryptionSettingsRead_ /
// encryptionSettingsWritten_ flags below, these presumably exchange the
// encryption settings with the peer at most once per connection - confirm in
// the definitions.
void readEncryptionSettingsOnce(int timeoutMs);
void writeEncryptionSettingsOnce();
/// If doTagIOs is false will not try to read/write the final encryption tag
virtual ErrorCode closeConnectionInternal(bool doTagIOs);
ErrorCode finalizeWrites(bool doTagIOs);
ErrorCode finalizeReads(bool doTagIOs);
// port and file descriptor; both default to -1
int port_{-1};
int fd_{-1};
// thread context, held by reference (not owned by this object)
ThreadCtx &threadCtx_;
EncryptionParams encryptionParams_;
// set to 0 by disableIvChange() to turn off the periodic iv change
int64_t ivChangeInterval_{0};
Func tagVerificationSuccessCallback_{nullptr};
bool encryptionSettingsWritten_{false};
bool encryptionSettingsRead_{false};
std::unique_ptr<AESEncryptor> encryptor_;
std::unique_ptr<AESDecryptor> decryptor_;
/// buffer used to encrypt/decrypt
char buf_[Protocol::kEncryptionCmdLen];
int32_t readTagInterval_{0};
int32_t writeTagInterval_{0};
// counters returned by getNumRead() / getNumWritten()
int64_t totalRead_{0};
int64_t totalWritten_{0};
/// If this is true, then if we get a cmd other than encryption cmd from the
/// peer, we expect the other side to not be encryption aware. We turn off
/// encryption in that case
bool supportUnencryptedPeer_{false};
// need two error codes because a socket does bi-directional communication
ErrorCode readErrorCode_{OK};
ErrorCode writeErrorCode_{OK};
/// Have we already completed encryption and written the tag
bool writesFinalized_{false};
/// Have we already read the tag and completed decryption
bool readsFinalized_{false};
};
}
}