in csrc/liars_dice/rela/prioritized_replay.h [374:449]
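// Stratified proportional sampling: the total priority mass is split into
// `batchsize` equal segments and one element is drawn from each segment by
// walking storage_ once while accumulating per-element priorities.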
SampleWeightIds sample_with_priorities_(int batchsize,
const std::string& device) {
std::unique_lock<std::mutex> lk(mSampler_);
float sum;
int size = storage_.safeSize(&sum);
// std::cout << "size: "<< size << ", sum: " << sum << std::endl;
// storage_ [0, size) stays fixed for the rest of this locked section
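// width of one sampling stratum; sample i targets a point in
// [i * segment, (i + 1) * segment)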
float segment = sum / batchsize;
std::uniform_real_distribution<float> dist(0.0, segment);
std::vector<DataType> samples;
auto weights = torch::zeros({batchsize}, torch::kFloat32);
auto weightAcc = weights.accessor<float, 1>();
std::vector<int> ids(batchsize);
double accSum = 0;
int nextIdx = 0;
float w = 0;
int id = 0;
for (int i = 0; i < batchsize; i++) {
float rand = dist(rng_) + i * segment;
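// clamp the target slightly below the total priority mass so that float
// rounding cannot push it past what accSum can reach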
rand = std::min(sum - (float)0.1, rand);
// std::cout << "looking for " << i << "th/" << batchsize << " sample" <<
// std::endl;
// std::cout << "\ttarget: " << rand << std::endl;
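// advance through storage_ until the accumulated priority reaches the target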
while (nextIdx <= size) {
if ((accSum > 0 && accSum >= rand) || nextIdx == size) {
assert(nextIdx >= 1);
// std::cout << "\tfound: " << nextIdx - 1 << ", " << id << ", " <<
// accSum << std::endl;
DataType element = storage_.getElementAndMark(nextIdx - 1);
samples.push_back(element);
weightAcc[i] = w;
ids[i] = id;
break;
}
if (nextIdx == size) {
// This should never happen due to the hacky check in the if above.
std::cout << "nextIdx: " << nextIdx << "/" << size << std::endl;
std::cout << std::setprecision(10) << "accSum: " << accSum
<< ", sum: " << sum << ", rand: " << rand << std::endl;
assert(false);
}
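// target not reached yet: add the next element's priority and advance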
w = storage_.getWeight(nextIdx, &id);
accSum += w;
++nextIdx;
}
}
assert((int)samples.size() == batchsize);
// evict the oldest elements if storage has grown beyond capacity
size = storage_.size();
if (size > capacity_) {
storage_.blockPop(size - capacity_);
}
// safe to unlock, because <samples> contains copies
lk.unlock();
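// importance-sampling correction: w_i = (N * P(i))^(-beta), normalized by
// the largest weight in the batch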
weights = weights / sum;
weights = torch::pow(size * weights, -beta_);
weights /= weights.max();
if (device != "cpu") {
weights = weights.to(torch::Device(device));
}
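// collate the sampled transitions into a single batch on the target device;
// values are stored quantized when compression is enabled, so restore them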
auto batch = DataType::makeBatch(samples, device);
if (compressed_values_) {
batch.values = rela::dequantize(batch.values);
}
return std::make_tuple(batch, weights, ids);
}