in src/rpc.cc [626:700]
bool banditSend(
uint32_t mask, Buffer buffer, Deferrer& defer, size_t* indexUsed = nullptr,
Me<RpcConnectionImplBase>* outConnection = nullptr, bool shouldFindPeer = true) noexcept {
// log("banditSend %d bytes mask %#x\n", (int)buffer->size, mask);
auto now = std::chrono::steady_clock::now();
thread_local std::vector<std::pair<size_t, float>> list;
list.clear();
float sum = 0.0f;
bool hasCudaTensor = false;
auto* tensors = buffer->tensors();
for (size_t i = 0; i != buffer->nTensors; ++i) {
if (tensors[i].tensor.is_cuda()) {
hasCudaTensor = true;
break;
}
}
for (size_t i = 0; i != connections_.size(); ++i) {
if (~mask & (1 << i)) {
continue;
}
if (hasCudaTensor) {
fatal("CUDA tensors are currently not supported, sorry!");
bool supportsCuda = false;
switchOnAPI((ConnectionType)i, [&](auto api) { supportsCuda = decltype(api)::supportsCuda; });
if (!supportsCuda) {
continue;
}
}
auto& v = connections_[i];
if (willConnectOrSend(now, v)) {
float score = std::exp(v.readBanditValue * 4);
// log("bandit %s has score %g\n", connectionTypeName[i], score);
sum += score;
list.emplace_back(i, sum);
}
}
if (list.size() > 0) {
size_t index;
if (list.size() == 1) {
index = list[0].first;
} else {
float v = std::uniform_real_distribution<float>(0.0f, sum)(rng);
index = std::lower_bound(list.begin(), std::prev(list.end()), v, [&](auto& a, float b) {
return a.second < b;
})->first;
}
// log("bandit chose %d (%s)\n", index, connectionTypeName.at(index));
auto& x = connections_.at(index);
x.sendCount.fetch_add(1, std::memory_order_relaxed);
bool b = switchOnAPI(
(ConnectionType)index, [&](auto api) { return send<decltype(api)>(now, buffer, outConnection, defer); });
if (!b && buffer) {
mask &= ~(1 << index);
return banditSend(mask, std::move(buffer), defer, indexUsed, outConnection, shouldFindPeer);
}
if (b && indexUsed) {
*indexUsed = index;
}
return b;
} else {
// log("No connectivity to %s\n", name);
if (shouldFindPeer) {
int timeout = findThisPeerIncrementingTimeoutMilliseconds;
if (now - lastFindThisPeer.load(std::memory_order_relaxed) >= std::chrono::milliseconds(timeout)) {
log("findpeer timeout is %d\n", timeout);
findThisPeerIncrementingTimeoutMilliseconds.store(
std::min(std::max(timeout, 250) * 2, 1000), std::memory_order_relaxed);
lastFindThisPeer.store(now, std::memory_order_relaxed);
findPeer();
}
}
return false;
}
}