SampleWeightIds sample_with_priorities_()

in csrc/liars_dice/rela/prioritized_replay.h [374:449]


  /// Samples `batchsize` elements with probability proportional to their
  /// priority weight and returns them together with importance-sampling
  /// weights.
  ///
  /// Only the prefix [0, size) reported by storage_.safeSize() is considered;
  /// that prefix and its total weight `sum` are read while holding mSampler_.
  /// The mass `sum` is split into `batchsize` equal segments, one target is
  /// drawn uniformly inside each segment (stratified sampling), and the
  /// element whose cumulative-weight interval contains the target is picked.
  /// Targets grow with the segment index, so one forward scan over storage_
  /// serves the whole batch, and the same element may be picked by several
  /// segments.
  ///
  /// Before the lock is released, storage_ is trimmed with blockPop when
  /// storage_.size() exceeds capacity_.
  ///
  /// @param batchsize number of samples to draw; must be positive.
  /// @param device    torch device string; the weights are moved there unless
  ///                  it is "cpu", and it is forwarded to DataType::makeBatch.
  /// @return tuple of (batch, importance weights of shape {batchsize}
  ///         normalised so that the largest is 1, per-sample ids as reported
  ///         by storage_.getWeight).
  /// @throws std::runtime_error if batchsize is not positive, or if the
  ///         sampleable prefix is empty or its total weight is not positive.
  SampleWeightIds sample_with_priorities_(int batchsize,
                                          const std::string& device) {
    if (batchsize <= 0) {
      throw std::runtime_error(
          "sample_with_priorities_: batchsize must be positive, got " +
          std::to_string(batchsize));
    }

    std::unique_lock<std::mutex> lk(mSampler_);

    float sum;
    const int size = storage_.safeSize(&sum);
    // storage_ [0, size) remains static in the subsequent section.
    if (size <= 0 || !(sum > 0.0f)) {
      // Sampling needs at least one element with positive mass: with an empty
      // prefix the scan below would read element -1, and with zero total mass
      // it would select zero-weight elements whose importance weight becomes
      // inf/NaN. Written as `!(sum > 0)` so that a NaN sum is rejected too.
      // The unique_lock releases mSampler_ while the exception propagates.
      throw std::runtime_error(
          "sample_with_priorities_: nothing to sample (size=" +
          std::to_string(size) + ", sum=" + std::to_string(sum) + ")");
    }

    const float segment = sum / batchsize;
    std::uniform_real_distribution<float> dist(0.0, segment);

    std::vector<DataType> samples;
    samples.reserve(batchsize);
    auto weights = torch::zeros({batchsize}, torch::kFloat32);
    auto weightAcc = weights.accessor<float, 1>();
    std::vector<int> ids(batchsize);

    // Scan state shared by all targets: accSum is the total weight of
    // elements [0, nextIdx); w and id belong to element nextIdx - 1.
    double accSum = 0;
    int nextIdx = 0;
    float w = 0;
    int id = 0;
    for (int i = 0; i < batchsize; i++) {
      // Target inside segment i. It is deliberately not clamped below `sum`:
      // a target that overshoots the accumulated mass (float rounding at the
      // top end) stops the scan at nextIdx == size and picks the last element,
      // which is where such a target belongs. A fixed-offset clamp such as
      // `sum - 0.1` would go negative for small total mass and send every
      // target to the first positive-weight element.
      const float target = dist(rng_) + i * segment;

      // Advance until element nextIdx - 1 covers the target, or the prefix is
      // exhausted. `accSum > 0` prevents stopping before any positive mass has
      // been seen (e.g. a target of exactly 0 at the start of the scan).
      while (!(accSum > 0 && accSum >= target) && nextIdx < size) {
        w = storage_.getWeight(nextIdx, &id);
        accSum += w;
        ++nextIdx;
      }
      // size >= 1 was checked above, so at least one element has been read.
      assert(nextIdx >= 1);
      samples.push_back(storage_.getElementAndMark(nextIdx - 1));
      weightAcc[i] = w;
      ids[i] = id;
    }
    assert(static_cast<int>(samples.size()) == batchsize);

    // Pop storage if full. storage_.size() is a different query from the
    // sampled prefix `size`, so it is kept in its own variable.
    const int storedSize = storage_.size();
    if (storedSize > capacity_) {
      storage_.blockPop(storedSize - capacity_);
    }

    // safe to unlock, because <samples> contains copies
    lk.unlock();

    // Importance-sampling weights: P(j) = w_j / sum, weight_j = (N * P(j))^-beta,
    // then normalised by the maximum. The normalisation cancels the constant
    // factor N^-beta, so N only has to be positive; the sampled prefix size
    // is used as N.
    weights = weights / sum;
    weights = torch::pow(size * weights, -beta_);
    weights /= weights.max();
    if (device != "cpu") {
      weights = weights.to(torch::Device(device));
    }
    auto batch = DataType::makeBatch(samples, device);
    if (compressed_values_) {
      batch.values = rela::dequantize(batch.values);
    }
    return std::make_tuple(batch, weights, ids);
  }