Protocol.cpp (500 lines of code) (raw):
/**
* Copyright (c) 2014-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <wdt/Protocol.h>
#include <wdt/ErrorCodes.h>
#include <wdt/WdtOptions.h>
#include <wdt/util/SerializationUtil.h>
namespace facebook {
namespace wdt {
using std::string;
using folly::ByteRange;
// Protocol version of this build, taken from the WDT_PROTOCOL_VERSION macro.
const int Protocol::protocol_version = WDT_PROTOCOL_VERSION;
// The remaining constants are protocol version numbers. The ones referenced in
// this file are compared against a peer's protocol version to decide which
// optional fields are encoded/decoded.
// NOTE(review): the constants not referenced in this file presumably gate the
// like-named features elsewhere - confirm against their users.
const int Protocol::RECEIVER_PROGRESS_REPORT_VERSION = 11;
const int Protocol::CHECKSUM_VERSION = 12;
const int Protocol::DOWNLOAD_RESUMPTION_VERSION = 13;
// From this version on, settings carry a trailing flags byte
// (see encodeSettings/decodeSettings).
const int Protocol::SETTINGS_FLAG_VERSION = 12;
// From this version on, a block header carries an allocation-status byte and,
// for some statuses, a previous seq-id (see encodeHeader/decodeHeader).
const int Protocol::HEADER_FLAG_AND_PREV_SEQ_ID_VERSION = 13;
// From this version on, checkpoints carry lastBlockReceivedBytes and the done
// command carries bytesSent.
const int Protocol::CHECKPOINT_OFFSET_VERSION = 16;
// From this version on, checkpoints carry the last block seq-id and offset.
const int Protocol::CHECKPOINT_SEQ_ID_VERSION = 21;
const int Protocol::ENCRYPTION_V1_VERSION = 23;
const int Protocol::INCREMENTAL_TAG_VERIFICATION_VERSION = 25;
const int Protocol::DELETE_CMD_VERSION = 26;
// Versions below this one need special handling when decoding a checkpoint
// seq-id of -1 (see decodeCheckpoints).
const int Protocol::VARINT_CHANGE = 27;
const int Protocol::HEART_BEAT_VERSION = 29;
const int Protocol::PERIODIC_ENCRYPTION_IV_CHANGE_VERSION = 30;
/* All methods of Protocol class are static (functions) */
// Returns WDT_VERSION_STR followed by " p " and the decimal value of
// protocol_version, as a single string.
const string Protocol::getFullVersion() {
  const string protocolPart = std::to_string(protocol_version);
  return string(WDT_VERSION_STR) + " p " + protocolPart;
}
// Chooses the protocol version to use with a peer.
// Returns 0 (after logging a warning) when the requested version is below 10;
// otherwise returns the smaller of the requested and the current version.
int Protocol::negotiateProtocol(int requestedProtocolVersion,
                                int curProtocolVersion) {
  const int kOldestAcceptedVersion = 10;
  if (requestedProtocolVersion >= kOldestAcceptedVersion) {
    return (requestedProtocolVersion < curProtocolVersion)
               ? requestedProtocolVersion
               : curProtocolVersion;
  }
  WLOG(WARNING) << "Can not handle protocol " << requestedProtocolVersion;
  return 0;
}
// Writes a one-line, human-readable description of a checkpoint to os:
// port, number of blocks, last block seq-id, offset and received bytes.
std::ostream &operator<<(std::ostream &os, const Checkpoint &checkpoint) {
  os << "checkpoint-port: " << checkpoint.port;
  os << " num-blocks: " << checkpoint.numBlocks;
  os << " seq-id: " << checkpoint.lastBlockSeqId;
  os << " block-offset: " << checkpoint.lastBlockOffset;
  os << " received-bytes: " << checkpoint.lastBlockReceivedBytes;
  return os;
}
int64_t FileChunksInfo::getTotalChunkSize() const {
int64_t totalChunkSize = 0;
for (const auto &chunk : chunks_) {
totalChunkSize += chunk.size();
}
return totalChunkSize;
}
// Appends a copy of the given chunk to chunks_. No sorting or merging is done
// here; see mergeChunks().
void FileChunksInfo::addChunk(const Interval &chunk) {
  chunks_.push_back(chunk);
}
void FileChunksInfo::mergeChunks() {
if (chunks_.empty()) {
return;
}
std::sort(chunks_.begin(), chunks_.end());
std::vector<Interval> mergedChunks;
Interval curChunk = chunks_[0];
const size_t numChunks = chunks_.size();
for (size_t i = 1; i < numChunks; i++) {
if (chunks_[i].start_ > curChunk.end_) {
mergedChunks.emplace_back(curChunk);
curChunk = chunks_[i];
} else {
curChunk.end_ = std::max(curChunk.end_, chunks_[i].end_);
}
}
mergedChunks.emplace_back(curChunk);
chunks_ = mergedChunks;
}
std::vector<Interval> FileChunksInfo::getRemainingChunks(int64_t curFileSize) {
std::vector<Interval> remainingChunks;
int64_t curStart = 0;
for (const auto &chunk : chunks_) {
if (chunk.start_ > curStart) {
remainingChunks.emplace_back(curStart, chunk.start_);
}
curStart = chunk.end_;
}
if (curStart < curFileSize) {
remainingChunks.emplace_back(curStart, curFileSize);
}
return remainingChunks;
}
// Writes the file name, seq-id, file size and chunk count of fileChunksInfo to
// os, followed by each chunk as " (start, end) ".
std::ostream &operator<<(std::ostream &os,
                         FileChunksInfo const &fileChunksInfo) {
  const auto &chunkList = fileChunksInfo.getChunks();
  os << "name " << fileChunksInfo.getFileName();
  os << " seqId " << fileChunksInfo.getSeqId();
  os << " file-size " << fileChunksInfo.getFileSize();
  os << " number of chunks " << chunkList.size();
  for (auto it = chunkList.begin(); it != chunkList.end(); ++it) {
    os << " (" << it->start_ << ", " << it->end_ << ") ";
  }
  return os;
}
// Returns the buffer length budgeted for an encoded local checkpoint (a
// checkpoint vector of size 1) at the given protocol version. Every field is
// budgeted at 10 bytes.
int Protocol::getMaxLocalCheckpointLength(int protocolVersion) {
  const int kBytesPerField = 10;
  // The vector size is budgeted at a full 10 bytes although 1 byte suffices:
  // earlier versions of the code assumed 10, so the value is kept.
  // The other two always-present fields are the port and the number of blocks.
  int numFields = 1 + 2;
  if (protocolVersion >= CHECKPOINT_OFFSET_VERSION) {
    numFields += 1;  // number of bytes in the last block
  }
  if (protocolVersion >= CHECKPOINT_SEQ_ID_VERSION) {
    numFields += 2;  // seq-id and block offset
  }
  return numFields * kBytesPerField;
}
// Encodes a block header into dest starting at off (off is advanced as bytes
// are written; max is the buffer limit and must not be negative).
// Fields, in order: file name, seq-id, data size, offset, file size. For
// senderProtocolVersion >= HEADER_FLAG_AND_PREV_SEQ_ID_VERSION one byte with
// the allocation status follows, and after it the previous seq-id when the
// status is EXISTS_TOO_SMALL or EXISTS_TOO_LARGE.
// Returns false (and logs an error) if any field could not be encoded.
bool Protocol::encodeHeader(int senderProtocolVersion, char *dest, int64_t &off,
                            const int64_t max,
                            const BlockDetails &blockDetails) {
  WDT_CHECK_GE(max, 0);
  const size_t umax = static_cast<size_t>(max);  // safe: max checked above
  bool ok = encodeString(dest, max, off, blockDetails.fileName);
  ok = ok && encodeVarI64C(dest, umax, off, blockDetails.seqId);
  ok = ok && encodeVarI64C(dest, umax, off, blockDetails.dataSize);
  ok = ok && encodeVarI64C(dest, umax, off, blockDetails.offset);
  ok = ok && encodeVarI64C(dest, umax, off, blockDetails.fileSize);
  const bool hasFlagsByte =
      senderProtocolVersion >= HEADER_FLAG_AND_PREV_SEQ_ID_VERSION;
  if (ok && hasFlagsByte) {
    const uint8_t flags = blockDetails.allocationStatus;
    if (off < max) {
      dest[off] = static_cast<char>(flags);
      ++off;
      const bool hasPrevSeqId =
          (flags == EXISTS_TOO_SMALL || flags == EXISTS_TOO_LARGE);
      if (hasPrevSeqId) {
        ok = encodeVarI64C(dest, umax, off, blockDetails.prevSeqId);
      }
    } else {
      // no room left for the flags byte
      ok = false;
    }
  }
  if (ok) {
    return true;
  }
  WLOG(ERROR) << "Failed to encode header, ran out of space, " << off << " "
              << max;
  return false;
}
bool Protocol::decodeHeader(int receiverProtocolVersion, char *src,
int64_t &off, const int64_t max,
BlockDetails &blockDetails) {
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
bool ok = decodeString(br, blockDetails.fileName) &&
decodeInt64C(br, blockDetails.seqId) &&
decodeInt64C(br, blockDetails.dataSize) &&
decodeInt64C(br, blockDetails.offset) &&
decodeInt64C(br, blockDetails.fileSize);
if (ok && receiverProtocolVersion >= HEADER_FLAG_AND_PREV_SEQ_ID_VERSION) {
if (br.empty()) {
WLOG(ERROR) << "Invalid (too short) input len " << max << " at offset "
<< (max - obr.size());
return false;
}
uint8_t flags = br.front();
// first 3 bits are used to represent allocation status
blockDetails.allocationStatus = (FileAllocationStatus)(flags & 7);
br.pop_front();
if (blockDetails.allocationStatus == EXISTS_TOO_SMALL ||
blockDetails.allocationStatus == EXISTS_TOO_LARGE) {
ok = decodeInt64C(br, blockDetails.prevSeqId);
}
}
off += offset(br, obr);
return ok;
}
bool Protocol::encodeCheckpoints(int protocolVersion, char *dest, int64_t &off,
int64_t max,
const std::vector<Checkpoint> &checkpoints) {
WDT_CHECK_GE(max, 0);
const size_t umax = static_cast<size_t>(max);
bool ok = encodeVarU64(dest, umax, off, checkpoints.size());
for (const auto &checkpoint : checkpoints) {
if (!ok) {
break;
}
ok = encodeVarI64C(dest, umax, off, checkpoint.port) &&
encodeVarI64C(dest, umax, off, checkpoint.numBlocks);
if (ok && protocolVersion >= CHECKPOINT_OFFSET_VERSION) {
ok = encodeVarI64C(dest, umax, off, checkpoint.lastBlockReceivedBytes);
}
if (ok && protocolVersion >= CHECKPOINT_SEQ_ID_VERSION) {
ok = encodeVarI64C(dest, umax, off, checkpoint.lastBlockSeqId) &&
encodeVarI64C(dest, umax, off, checkpoint.lastBlockOffset);
}
}
if (!ok) {
WLOG(ERROR) << "encodeCheckpoints " << off << " " << max;
}
return ok;
}
bool Protocol::decodeCheckpoints(int protocolVersion, char *src, int64_t &off,
int64_t max,
std::vector<Checkpoint> &checkpoints) {
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
uint64_t len;
bool ok = decodeUInt64(br, len);
for (uint64_t i = 0; ok && i < len; i++) {
Checkpoint checkpoint;
ok = decodeInt32C(br, checkpoint.port) &&
decodeInt64C(br, checkpoint.numBlocks);
if (ok && protocolVersion >= CHECKPOINT_OFFSET_VERSION) {
ok = decodeInt64C(br, checkpoint.lastBlockReceivedBytes);
}
if (ok && protocolVersion >= CHECKPOINT_SEQ_ID_VERSION) {
// Deal with -1 encoded by pre 1.27 version
uint64_t uv = 0;
ok = decodeUInt64(br, uv);
checkpoint.lastBlockSeqId = static_cast<int64_t>(uv);
if (ok && protocolVersion < VARINT_CHANGE) {
// pre 1.27 encodes -1 for invalid and use 9 0xff bytes and 1 0x01 byte
// 1.27+ decodes the 9 0xff as max uint64_t (all FFs) so we check and
// consume the leftover 0x01
if (uv == 0xffffffffffffffff && !br.empty()) {
if (br.front() != 0x01) {
WLOG(ERROR) << "Unexpected decoding of pre1.27 -1 : " << br.front();
ok = false;
}
br.advance(1); // 1.26 used 10 bytes
WLOG(INFO) << "Fixed v" << protocolVersion << " chkpt to -1 seqid";
checkpoint.lastBlockSeqId = -1;
}
}
ok = ok && decodeInt64C(br, checkpoint.lastBlockOffset);
checkpoint.hasSeqId = true;
}
if (ok) {
checkpoints.emplace_back(checkpoint);
}
}
off += offset(br, obr);
return ok;
}
// Encodes the done command payload into dest at off: numBlocks, followed by
// bytesSent when protocolVersion >= CHECKPOINT_OFFSET_VERSION.
// Returns false if a field could not be encoded.
bool Protocol::encodeDone(int protocolVersion, char *dest, int64_t &off,
                          int64_t max, int64_t numBlocks, int64_t bytesSent) {
  if (!encodeVarI64C(dest, max, off, numBlocks)) {
    return false;
  }
  if (protocolVersion < CHECKPOINT_OFFSET_VERSION) {
    // older peers do not expect the byte count
    return true;
  }
  return encodeVarI64C(dest, max, off, bytesSent);
}
bool Protocol::decodeDone(int protocolVersion, char *src, int64_t &off,
int64_t max, int64_t &numBlocks, int64_t &bytesSent) {
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
bool ok = decodeInt64C(br, numBlocks);
if (ok && protocolVersion >= CHECKPOINT_OFFSET_VERSION) {
ok = decodeInt64C(br, bytesSent);
}
off += offset(br, obr);
return ok;
}
// Encodes totalNumBytes into dest at off via encodeVarI64C and returns that
// call's result.
bool Protocol::encodeSize(char *dest, int64_t &off, int64_t max,
                          int64_t totalNumBytes) {
  const bool encoded = encodeVarI64C(dest, max, off, totalNumBytes);
  return encoded;
}
bool Protocol::decodeSize(char *src, int64_t &off, int64_t max,
int64_t &totalNumBytes) {
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
bool ok = decodeInt64C(br, totalNumBytes);
off += offset(br, obr);
return ok;
}
// Encodes an abort message into dest at off: fixed-length protocol version,
// one byte holding errCode, then fixed-length checkpoint.
// Requires kAbortLength bytes between off and max; otherwise logs an error and
// returns false without writing anything.
bool Protocol::encodeAbort(char *dest, int64_t &off, const int64_t max,
                           int32_t protocolVersion, ErrorCode errCode,
                           int64_t checkpoint) {
  const bool fits = (off + kAbortLength <= max);
  if (!fits) {
    WLOG(ERROR) << "Trying to encode abort in too small of a buffer sz " << max
                << " off " << off;
    return false;
  }
  if (!encodeInt32FixedLength(dest, max, off, protocolVersion)) {
    return false;
  }
  dest[off] = errCode;
  ++off;
  return encodeInt64FixedLength(dest, max, off, checkpoint);
}
bool Protocol::decodeAbort(char *src, int64_t &off, int64_t max,
int32_t &protocolVersion, ErrorCode &errCode,
int64_t &checkpoint) {
if (off + kAbortLength > max) {
WLOG(ERROR) << "Trying to decode abort, not enough to read sz " << max
<< " at off " << off;
return false;
}
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
bool ok = decodeInt32FixedLength(br, protocolVersion);
if (!ok) {
return false;
}
errCode = (ErrorCode)br.front();
br.pop_front();
ok = decodeInt64FixedLength(br, checkpoint);
off += offset(br, obr);
return ok;
}
// Encodes the chunks command payload into dest at off: fixed-length bufSize
// followed by fixed-length numFiles. numFiles is not attempted if bufSize
// fails to encode.
bool Protocol::encodeChunksCmd(char *dest, int64_t &off, int64_t max,
                               int64_t bufSize, int64_t numFiles) {
  if (!encodeInt64FixedLength(dest, max, off, bufSize)) {
    return false;
  }
  return encodeInt64FixedLength(dest, max, off, numFiles);
}
bool Protocol::decodeChunksCmd(char *src, int64_t &off, int64_t max,
int64_t &bufSize, int64_t &numFiles) {
ByteRange br = makeByteRange(src, max, off); // will check for off>0 max>0
const ByteRange obr = br;
bool ok = decodeInt64FixedLength(br, bufSize) &&
decodeInt64FixedLength(br, numFiles);
off += offset(br, obr);
return ok;
}
// Encodes one chunk into dest at off as its start_ followed by its end_.
// end_ is not attempted if start_ fails to encode.
bool Protocol::encodeChunkInfo(char *dest, int64_t &off, int64_t max,
                               const Interval &chunk) {
  if (!encodeVarI64C(dest, max, off, chunk.start_)) {
    return false;
  }
  return encodeVarI64C(dest, max, off, chunk.end_);
}
// Decodes one chunk from br: start_ followed by end_. br is advanced by the
// decode helpers; end_ is not attempted if start_ fails to decode.
bool Protocol::decodeChunkInfo(ByteRange &br, Interval &chunk) {
  if (!decodeInt64C(br, chunk.start_)) {
    return false;
  }
  return decodeInt64C(br, chunk.end_);
}
// Encodes one FileChunksInfo into dest at off: seq-id, file name, file size,
// number of chunks, then every chunk via encodeChunkInfo.
// Returns false at the first field or chunk that fails to encode; off may
// already have been advanced past the fields written before the failure.
bool Protocol::encodeFileChunksInfo(char *dest, int64_t &off, int64_t max,
                                    const FileChunksInfo &fileChunksInfo) {
  const auto &chunkList = fileChunksInfo.getChunks();
  const bool headerOk =
      encodeVarI64C(dest, max, off, fileChunksInfo.getSeqId()) &&
      encodeString(dest, max, off, fileChunksInfo.getFileName()) &&
      encodeVarI64C(dest, max, off, fileChunksInfo.getFileSize()) &&
      encodeVarI64C(dest, max, off, chunkList.size());
  if (!headerOk) {
    return false;
  }
  for (auto it = chunkList.begin(); it != chunkList.end(); ++it) {
    if (!encodeChunkInfo(dest, off, max, *it)) {
      return false;
    }
  }
  return true;
}
// Decodes one FileChunksInfo from br: seq-id, file name, file size, number of
// chunks, then that many chunks (each appended with addChunk).
// Returns false if any field fails to decode or if the decoded chunk count is
// negative (logged as an error). seq-id, name and size are stored into
// fileChunksInfo before the chunk count is validated.
bool Protocol::decodeFileChunksInfo(ByteRange &br,
                                    FileChunksInfo &fileChunksInfo) {
  int64_t seqId;
  string fileName;
  int64_t fileSize;
  int64_t numChunks;
  if (!decodeInt64C(br, seqId) || !decodeString(br, fileName) ||
      !decodeInt64C(br, fileSize) || !decodeInt64C(br, numChunks)) {
    return false;
  }
  fileChunksInfo.setSeqId(seqId);
  fileChunksInfo.setFileName(fileName);
  fileChunksInfo.setFileSize(fileSize);
  if (numChunks < 0) {
    WLOG(ERROR) << "Negative number of chunks decoded " << numChunks;
    return false;
  }
  for (int64_t remaining = numChunks; remaining > 0; --remaining) {
    Interval chunk;
    if (!decodeChunkInfo(br, chunk)) {
      return false;
    }
    fileChunksInfo.addChunk(chunk);
  }
  return true;
}
// Returns the length budgeted for encoding fileChunkInfo:
// 10 + (2 + file name length) + 10 + 10 + kMaxChunkEncodeLen per chunk.
// NOTE(review): presumably the 10s are seq-id, file size and chunk count, and
// the 2 is the encoded name length - confirm against encodeFileChunksInfo.
int64_t Protocol::maxEncodeLen(const FileChunksInfo &fileChunkInfo) {
  const auto nameLen = fileChunkInfo.getFileName().size();
  const auto allChunksLen =
      fileChunkInfo.getChunks().size() * kMaxChunkEncodeLen;
  return 10 + 2 + nameLen + 10 + 10 + allChunksLen;
}
int64_t Protocol::encodeFileChunksInfoList(
char *dest, int64_t &off, int64_t bufSize, int64_t startIndex,
const std::vector<FileChunksInfo> &fileChunksInfoList) {
int64_t oldOffset = off;
int64_t numEncoded = 0;
const int64_t numFileChunks = fileChunksInfoList.size();
for (int64_t i = startIndex; i < numFileChunks; i++) {
const FileChunksInfo &fileChunksInfo = fileChunksInfoList[i];
int64_t maxLength = maxEncodeLen(fileChunksInfo);
if (maxLength + oldOffset > bufSize) {
WLOG(WARNING) << "Chunk info for " << fileChunksInfo.getFileName()
<< " can not be encoded in a buffer of size " << bufSize
<< ", Ignoring.";
continue;
}
if (maxLength + off >= bufSize) {
break;
}
encodeFileChunksInfo(dest, off, bufSize, fileChunksInfo);
numEncoded++;
}
return numEncoded;
}
bool Protocol::decodeFileChunksInfoList(
char *src, int64_t &off, int64_t dataSize,
std::vector<FileChunksInfo> &fileChunksInfoList) {
ByteRange br = makeByteRange(src, dataSize, off);
const ByteRange obr = br;
while (!br.empty()) {
FileChunksInfo fileChunkInfo;
if (!decodeFileChunksInfo(br, fileChunkInfo)) {
return false;
}
fileChunksInfoList.emplace_back(std::move(fileChunkInfo));
}
off += offset(br, obr);
return true;
}
// Encodes the settings message into dest at off: sender protocol version,
// read timeout, write timeout and transfer id. For senderProtocolVersion >=
// SETTINGS_FLAG_VERSION one flags byte follows:
//   bit 0 enableChecksum, bit 1 sendFileChunks,
//   bit 2 blockModeDisabled, bit 3 enableHeartBeat.
// Returns false if a field fails to encode or there is no room for the flags
// byte.
bool Protocol::encodeSettings(int senderProtocolVersion, char *dest,
                              int64_t &off, int64_t max,
                              const Settings &settings) {
  const bool ok =
      encodeVarI64C(dest, max, off, senderProtocolVersion) &&
      encodeVarI64C(dest, max, off, settings.readTimeoutMillis) &&
      encodeVarI64C(dest, max, off, settings.writeTimeoutMillis) &&
      encodeString(dest, max, off, settings.transferId);
  if (!ok || senderProtocolVersion < SETTINGS_FLAG_VERSION) {
    // either a failure, or a peer version that has no flags byte
    return ok;
  }
  if (off >= max) {
    return false;
  }
  uint8_t flags = 0;
  flags |= settings.enableChecksum ? 1 : 0;
  flags |= settings.sendFileChunks ? (1 << 1) : 0;
  flags |= settings.blockModeDisabled ? (1 << 2) : 0;
  flags |= settings.enableHeartBeat ? (1 << 3) : 0;
  dest[off] = flags;
  ++off;
  return ok;
}
bool Protocol::decodeVersion(char *src, int64_t &off, int64_t max,
int &senderProtocolVersion) {
ByteRange br = makeByteRange(src, max, off);
const ByteRange obr = br;
bool ok = decodeInt32C(br, senderProtocolVersion);
off += offset(br, obr);
return ok;
}
// Decodes the settings message body from src[off, max) into settings: read
// timeout, write timeout and transfer id, followed - for protocolVersion >=
// SETTINGS_FLAG_VERSION - by one flags byte:
//   bit 0 enableChecksum, bit 1 sendFileChunks,
//   bit 2 blockModeDisabled, bit 3 enableHeartBeat.
// All four flags are reset to false up front, so a version without a flags
// byte (or a failed decode) never leaves stale flag values in settings.
// Returns false (off untouched) for a negative or past-the-end start offset;
// returns false when a field fails to decode or the flags byte is missing.
// Otherwise off is advanced by offset(br, obr) - presumably the bytes
// consumed.
bool Protocol::decodeSettings(int protocolVersion, char *src, int64_t &off,
                              int64_t max, Settings &settings) {
  settings.enableChecksum = false;
  settings.sendFileChunks = false;
  settings.blockModeDisabled = false;
  settings.enableHeartBeat = false;
  if (off < 0) {
    WLOG(ERROR) << "Invalid negative start offset for decodeSettings " << off;
    return false;
  }
  if (off >= max) {
    WLOG(ERROR) << "Invalid start offset at the end for decodeSettings " << off;
    return false;
  }
  ByteRange br = makeByteRange(src, max, off);
  const ByteRange obr = br;
  bool ok = decodeInt32C(br, settings.readTimeoutMillis) &&
            decodeInt32C(br, settings.writeTimeoutMillis) &&
            decodeString(br, settings.transferId);
  if (ok && protocolVersion >= SETTINGS_FLAG_VERSION) {
    if (br.empty()) {
      // the flags byte is mandatory for this version
      return false;
    }
    const uint8_t flags = br.front();
    settings.enableChecksum = flags & 1;
    settings.sendFileChunks = flags & (1 << 1);
    settings.blockModeDisabled = flags & (1 << 2);
    settings.enableHeartBeat = flags & (1 << 3);
    br.pop_front();
  }
  off += offset(br, obr);
  return ok;
}
/* static */
// Encodes the encryption settings into dest at off: encryption type, iv
// string, then fixed-length tag interval. Stops at the first field that fails
// to encode and returns false.
bool Protocol::encodeEncryptionSettings(char *dest, int64_t &off, int64_t max,
                                        const EncryptionType encryptionType,
                                        const string &iv,
                                        const int32_t tagInterval) {
  if (!encodeVarI64C(dest, max, off, encryptionType)) {
    return false;
  }
  if (!encodeString(dest, max, off, iv)) {
    return false;
  }
  return encodeInt32FixedLength(dest, max, off, tagInterval);
}
/* static */
bool Protocol::decodeEncryptionSettings(char *src, int64_t &off, int64_t max,
EncryptionType &encryptionType,
string &iv, int32_t &tagInterval) {
ByteRange br = makeByteRange(src, max, off);
const ByteRange obr = br;
int64_t v;
bool ok = decodeInt64C(br, v) && decodeString(br, iv) &&
decodeInt32FixedLength(br, tagInterval);
if (ok) {
encryptionType = static_cast<EncryptionType>(v);
}
off += offset(br, obr);
return ok;
}
// Encodes the checksum into dest at off via encodeVarI64 and returns that
// call's result.
bool Protocol::encodeFooter(char *dest, int64_t &off, int64_t max,
                            int32_t checksum) {
  const bool encoded = encodeVarI64(dest, max, off, checksum);
  return encoded;
}
bool Protocol::decodeFooter(char *src, int64_t &off, int64_t max,
int32_t &checksum) {
ByteRange br = makeByteRange(src, max, off);
const ByteRange obr = br;
bool ok = decodeInt32(br, checksum);
off += offset(br, obr);
return ok;
}
}
}