bool banditSend()

in src/rpc.cc [626:700]


  /// Sends `buffer` over one of the connection types enabled in `mask`.
  ///
  /// Among the enabled connections for which willConnectOrSend() is true, one is
  /// picked at random with weight exp(4 * readBanditValue), i.e. a softmax over
  /// the per-connection bandit values. If sending on the chosen connection fails
  /// and the buffer is still available, that connection is removed from the mask
  /// and the selection is retried with the remaining ones.
  ///
  /// @param mask            Bit i enables connections_[i]. The mask is 32 bits
  ///                        wide, so connection indices >= 32 are never selected.
  /// @param buffer          Message to send; moved into the retry on failure.
  /// @param defer           Forwarded to send<API>().
  /// @param indexUsed       If non-null, receives the index of the connection that
  ///                        sent successfully (left untouched on failure).
  /// @param outConnection   Forwarded to send<API>().
  /// @param shouldFindPeer  When no connection is usable, whether to trigger a
  ///                        rate-limited findPeer().
  /// @return the result of send<API>() for the connection finally used, or false
  ///         if no connection was usable.
  bool banditSend(
      uint32_t mask, Buffer buffer, Deferrer& defer, size_t* indexUsed = nullptr,
      Me<RpcConnectionImplBase>* outConnection = nullptr, bool shouldFindPeer = true) noexcept {
    // log("banditSend %d bytes mask %#x\n", (int)buffer->size, mask);
    constexpr size_t maskBits = 32; // width of the uint32_t `mask`
    auto now = std::chrono::steady_clock::now();
    // Reused per thread to avoid an allocation per send. Each entry is
    // (connection index, cumulative score of all candidates up to and including it).
    thread_local std::vector<std::pair<size_t, float>> list;
    list.clear();
    float sum = 0.0f;
    bool hasCudaTensor = false;
    auto* tensors = buffer->tensors();
    for (size_t i = 0; i != buffer->nTensors; ++i) {
      if (tensors[i].tensor.is_cuda()) {
        hasCudaTensor = true;
        break;
      }
    }
    for (size_t i = 0; i != connections_.size(); ++i) {
      // Unsigned, bounded bit test: `1 << i` is a signed-int shift and is undefined
      // for i >= 32, an index a 32-bit mask cannot enable in any case.
      if (i >= maskBits || (mask & (uint32_t(1) << i)) == 0) {
        continue;
      }
      if (hasCudaTensor) {
        fatal("CUDA tensors are currently not supported, sorry!");
        // NOTE(review): if fatal() does not return (confirm), the capability check
        // below is unreachable until CUDA support is enabled.
        bool supportsCuda = false;
        switchOnAPI((ConnectionType)i, [&](auto api) { supportsCuda = decltype(api)::supportsCuda; });
        if (!supportsCuda) {
          continue;
        }
      }
      auto& v = connections_[i];
      if (willConnectOrSend(now, v)) {
        float score = std::exp(v.readBanditValue * 4);
        // log("bandit %s has score %g\n", connectionTypeName[i], score);
        sum += score;
        list.emplace_back(i, sum);
      }
    }
    if (list.size() > 0) {
      size_t index;
      if (list.size() == 1) {
        // Single candidate: no need to draw a random number.
        index = list[0].first;
      } else {
        float v = std::uniform_real_distribution<float>(0.0f, sum)(rng);
        // Find the first candidate whose cumulative score reaches v. Searching only
        // up to prev(end) makes the last candidate the fallback, so a draw that
        // rounds to (or past) `sum` still selects a valid entry.
        index = std::lower_bound(list.begin(), std::prev(list.end()), v, [&](auto& a, float b) {
                  return a.second < b;
                })->first;
      }
      // log("bandit chose %d (%s)\n", index, connectionTypeName.at(index));
      auto& x = connections_.at(index);
      x.sendCount.fetch_add(1, std::memory_order_relaxed);
      bool b = switchOnAPI(
          (ConnectionType)index, [&](auto api) { return send<decltype(api)>(now, buffer, outConnection, defer); });
      if (!b && buffer) {
        // Send failed while we still hold the buffer (presumably send<> did not take
        // it; confirm): exclude this connection type and retry with the others.
        // index < maskBits is guaranteed by the filter above.
        mask &= ~(uint32_t(1) << index);
        return banditSend(mask, std::move(buffer), defer, indexUsed, outConnection, shouldFindPeer);
      }
      if (b && indexUsed) {
        *indexUsed = index;
      }
      return b;
    } else {
      // log("No connectivity to %s\n", name);

      if (shouldFindPeer) {
        // Rate-limit findPeer(): the next allowed interval becomes
        // min(max(timeout, 250) * 2, 1000) ms, i.e. 500 ms and then capped at 1000 ms.
        int timeout = findThisPeerIncrementingTimeoutMilliseconds;
        if (now - lastFindThisPeer.load(std::memory_order_relaxed) >= std::chrono::milliseconds(timeout)) {
          log("findpeer timeout is %d\n", timeout);
          findThisPeerIncrementingTimeoutMilliseconds.store(
              std::min(std::max(timeout, 250) * 2, 1000), std::memory_order_relaxed);
          lastFindThisPeer.store(now, std::memory_order_relaxed);
          findPeer();
        }
      }
      return false;
    }
  }