WdtBase.cpp (161 lines of code) (raw):

/**
 * Copyright (c) 2014-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <wdt/WdtBase.h>
#include <wdt/WdtTransferRequest.h>
#include <random>
using namespace std;

namespace facebook {
namespace wdt {

// TODO: force callers to pass options in
/// Constructs the object and fills options_ from WdtOptions::get() via
/// copyInto(). abortCheckerCallback_ is constructed with a pointer to this
/// object.
WdtBase::WdtBase() : abortCheckerCallback_(this) {
  options_.copyInto(WdtOptions::get());
}

/// Replaces the contents of options_ with those of `src` (via copyInto()).
void WdtBase::setWdtOptions(const WdtOptions& src) {
  options_.copyInto(src);
}

/// Releases this object's reference to the external abort checker and
/// deletes threadsController_.
WdtBase::~WdtBase() {
  abortChecker_ = nullptr;
  // NOTE(review): threadsController_ is a raw pointer deleted here; its
  // allocation site is not in this file.
  delete threadsController_;
}

/// Records `abortCode` as the current internal abort code.
///
/// VERSION_MISMATCH never overrides an abort code that is already set to
/// something other than OK. The update is done under the write side of
/// abortCodeLock_.
void WdtBase::abort(const ErrorCode abortCode) {
  folly::RWSpinLock::WriteHolder guard(abortCodeLock_);
  if (abortCode == VERSION_MISMATCH && abortCode_ != OK) {
    // VERSION_MISMATCH is the lowest priority abort code. If the abort code is
    // anything other than OK, we should not override it
    return;
  }
  WLOG(WARNING) << "Setting the abort code " << abortCode;
  abortCode_ = abortCode;
}

/// Resets the internal abort code to OK, but only if it is currently
/// VERSION_MISMATCH; any other abort code is left in place.
void WdtBase::clearAbort() {
  folly::RWSpinLock::WriteHolder guard(abortCodeLock_);
  if (abortCode_ != VERSION_MISMATCH) {
    // We do not clear the abort code unless it is VERSION_MISMATCH
    return;
  }
  WLOG(WARNING) << "Clearing the abort code";
  abortCode_ = OK;
}

/// Installs the external abort checker consulted by getCurAbortCode().
void WdtBase::setAbortChecker(const std::shared_ptr<IAbortChecker>& checker) {
  abortChecker_ = checker;
}

/// Returns the abort code currently in effect.
///
/// @return ABORTED_BY_APPLICATION if an external checker is installed and its
///         shouldAbort() returns true; otherwise the internal abortCode_
///         (read under the read side of abortCodeLock_).
ErrorCode WdtBase::getCurAbortCode() const {
  // external check, if any (done before taking the lock):
  if (abortChecker_ && abortChecker_->shouldAbort()) {
    return ABORTED_BY_APPLICATION;
  }
  folly::RWSpinLock::ReadHolder guard(abortCodeLock_);
  // internal check:
  return abortCode_;
}

/// Takes ownership of the given progress reporter. The caller's unique_ptr is
/// moved from and is therefore empty after this call.
void WdtBase::setProgressReporter(
    std::unique_ptr<ProgressReporter>& progressReporter) {
  progressReporter_ = std::move(progressReporter);
}

/// Installs an externally provided throttler.
void WdtBase::setThrottler(std::shared_ptr<Throttler> throttler) {
  WVLOG(2) << "Setting an external throttler";
  // The parameter is already our own copy; move it to avoid a second
  // reference-count round trip.
  throttler_ = std::move(throttler);
}

/// Returns the throttler currently in use (may be null if none was set or
/// configured).
std::shared_ptr<Throttler> WdtBase::getThrottler() const {
  return throttler_;
}

/// Stores `transferId` in the transfer request and logs it.
void WdtBase::setTransferId(const std::string& transferId) {
  transferRequest_.transferId = transferId;
  WLOG(INFO) << "Setting transfer id " << transferId;
}

/// Runs the protocol version of the transfer request through
/// Protocol::negotiateProtocol() and stores the result back in the request.
///
/// The requested version must be > 0 (enforced with WDT_CHECK). A warning is
/// logged when negotiation changes the version.
void WdtBase::negotiateProtocol() {
  int protocol = transferRequest_.protocolVersion;
  WDT_CHECK(protocol > 0) << "Protocol version can't be <= 0 " << protocol;
  int negotiatedPv = Protocol::negotiateProtocol(protocol);
  if (negotiatedPv != protocol) {
    WLOG(WARNING) << "Negotiated protocol version " << protocol << " -> "
                  << negotiatedPv;
  }
  transferRequest_.protocolVersion = negotiatedPv;
  WLOG(INFO) << "using wdt protocol version "
             << transferRequest_.protocolVersion;
}

/// Returns the protocol version stored in the transfer request.
int WdtBase::getProtocolVersion() const {
  return transferRequest_.protocolVersion;
}

/// Overwrites the protocol version stored in the transfer request.
void WdtBase::setProtocolVersion(int protocolVersion) {
  transferRequest_.protocolVersion = protocolVersion;
}

/// Returns a copy of the transfer id stored in the transfer request.
std::string WdtBase::getTransferId() {
  return transferRequest_.transferId;
}

/// Returns a reference to the directory stored in the transfer request.
const std::string& WdtBase::getDirectory() const {
  return transferRequest_.directory;
}

/// Returns a mutable reference to the transfer request held by this object.
WdtTransferRequest& WdtBase::getTransferRequest() {
  return transferRequest_;
}

/// Normalizes options_.buffer_size.
///
/// The size is first raised to at least Protocol::kMaxHeader and then rounded
/// up to the next multiple of kDiskBlockSize. A warning is logged for each
/// adjustment that is made.
void WdtBase::checkAndUpdateBufferSize() {
  int64_t bufSize = options_.buffer_size;
  if (bufSize < Protocol::kMaxHeader) {
    bufSize = Protocol::kMaxHeader;
    WLOG(WARNING) << "Specified buffer size " << options_.buffer_size
                  << " less than " << Protocol::kMaxHeader << ", using "
                  << bufSize;
  }
  if (bufSize % kDiskBlockSize != 0) {
    // Round up to the next multiple of the disk block size.
    int64_t alignedBufSize =
        ((bufSize + kDiskBlockSize - 1) / kDiskBlockSize) * kDiskBlockSize;
    WLOG(WARNING) << "Buffer size " << bufSize
                  << " not divisible by disk block size " << kDiskBlockSize
                  << ", changing it to " << alignedBufSize;
    bufSize = alignedBufSize;
  }
  options_.buffer_size = bufSize;
}

/// Returns the current transfer status (read while holding mutex_).
WdtBase::TransferStatus WdtBase::getTransferStatus() {
  std::lock_guard<std::mutex> lock(mutex_);
  return transferStatus_;
}

/// Validates the transfer request held by this object.
///
/// @return the request's existing error code if it is not OK;
///         INVALID_REQUEST (also stored into the request) if the directory is
///         empty, the protocol version is negative, or the port list is
///         empty; OK otherwise.
ErrorCode WdtBase::validateTransferRequest() {
  ErrorCode code = transferRequest_.errorCode;
  if (code != OK) {
    WLOG(ERROR) << "WDT object initiated with erroneous transfer request "
                << transferRequest_.getLogSafeString();
    return code;
  }
  if (transferRequest_.directory.empty() ||
      (transferRequest_.protocolVersion < 0) ||
      transferRequest_.ports.empty()) {
    WLOG(ERROR) << "Transfer request validation failed for wdt object "
                << transferRequest_.getLogSafeString();
    code = INVALID_REQUEST;
    transferRequest_.errorCode = code;
  }
  return code;
}

/// Updates the transfer status while holding mutex_. When the new status is
/// THREADS_JOINED, one waiter on conditionFinished_ is notified.
void WdtBase::setTransferStatus(TransferStatus transferStatus) {
  std::lock_guard<std::mutex> lock(mutex_);
  transferStatus_ = transferStatus;
  if (transferStatus_ == THREADS_JOINED) {
    conditionFinished_.notify_one();
  }
}

/// Returns true when the transfer status is FINISHED or THREADS_JOINED.
bool WdtBase::isStale() {
  TransferStatus status = getTransferStatus();
  return (status == FINISHED || status == THREADS_JOINED);
}

/// Returns true when the transfer status is anything other than NOT_STARTED.
bool WdtBase::hasStarted() {
  TransferStatus status = getTransferStatus();
  return (status != NOT_STARTED);
}

/// Creates the throttler from the throttler options in options_.
///
/// Must only be called when no throttler is set yet (enforced with
/// WDT_CHECK).
void WdtBase::configureThrottler() {
  WDT_CHECK(!throttler_);
  WVLOG(1) << "Configuring throttler options";
  // std::make_shared never returns null (it throws on allocation failure), so
  // the throttler is always valid past this point.
  throttler_ = std::make_shared<Throttler>(options_.getThrottlerOptions());
  WLOG(INFO) << "Enabling throttling " << *throttler_;
}

/// Generates a transfer id: the decimal string of one value drawn from a
/// process-wide random engine.
///
/// The engine is a function-local static seeded once from
/// std::random_device; draws are serialized with a function-local static
/// mutex.
std::string WdtBase::generateTransferId() {
  static std::default_random_engine randomEngine{std::random_device()()};
  static std::mutex mutex;
  std::string transferId;
  {
    std::lock_guard<std::mutex> lock(mutex);
    transferId = std::to_string(randomEngine());
  }
  WVLOG(1) << "Generated a transfer id " << transferId;
  return transferId;
}

}  // namespace wdt
}  // namespace facebook