in recipes/local_prior_match/src/runtime/DataScheduler.cpp [98:141]
void DataScheduler::update() {
  // Count one more iteration taken from the currently selected dataset.
  ++dsCurIter_[curDs_];

  // Reshuffle the current dataset whenever a full pass over it completes,
  // using its epoch counter as the shuffling seed.
  if (!FLAGS_noresample &&
      (dsIterOffset_[curDs_] + dsCurIter_[curDs_]) % ds_[curDs_]->size() == 0) {
    LOG_MASTER(INFO) << "Shuffling trainset";
    ds_[curDs_]->shuffle(++dsCurEpochs_[curDs_] /* seed */);
  }

  if (FLAGS_schedulerorder == kInOrder) {
    // In-order schedule: stay on the current dataset until its per-round
    // quota is used up, then advance to the next dataset with a nonzero quota.
    if (dsCurIter_[curDs_] % dsNumIters_[curDs_] == 0) {
      curDs_ = (curDs_ + 1) % ds_.size();
      while (dsNumIters_[curDs_] == 0) {
        curDs_ = (curDs_ + 1) % ds_.size();
      }
    }
  } else if (FLAGS_schedulerorder == kUniformOrder) {
    // Uniform schedule: pick the dataset that is furthest behind its target
    // pace. `offset` counts completed rounds of the per-dataset quota and
    // `ratio` measures fractional progress within the current round.
    double minVal = std::numeric_limits<double>::max();
    for (int i = 0; i < ds_.size(); ++i) {
      if (dsNumIters_[i] > 0) {
        int offset = dsCurIter_[i] / dsNumIters_[i];
        double ratio =
            1.0 / (dsNumIters_[i] + 1) * (dsCurIter_[i] % dsNumIters_[i] + 1);
        if (offset + ratio < minVal) {
          minVal = offset + ratio;
          curDs_ = i;
        }
      }
    }
  } else if (FLAGS_schedulerorder == kRandomOrder) {
    // Random schedule: sample the next dataset with probability proportional
    // to its remaining iteration budget in the current round. Decrementing
    // every cumulative sum from curDs_ onward removes one iteration from the
    // current dataset's budget.
    for (int c = curDs_; c < dsCumNumIters_.size(); ++c) {
      --dsCumNumIters_[c];
    }
    // Once all budgets are exhausted, start a new round by rebuilding the
    // cumulative sums from the per-dataset quotas.
    if (dsCumNumIters_.back() == 0) {
      std::partial_sum(
          dsNumIters_.begin(), dsNumIters_.end(), dsCumNumIters_.begin());
    }
    // Draw a value in [1, total remaining] and map it to a dataset index via
    // binary search over the cumulative sums.
    std::uniform_int_distribution<int> distribution(1, dsCumNumIters_.back());
    auto d = distribution(gen_);
    auto lit =
        std::lower_bound(dsCumNumIters_.begin(), dsCumNumIters_.end(), d);
    curDs_ = lit - dsCumNumIters_.begin();
  }
}
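
For reference, a minimal standalone sketch of the weighted pick used by the kRandomOrder branch: partial sums over per-dataset quotas turn a uniform draw in [1, total] into a dataset index via std::lower_bound, so datasets with larger remaining budgets are chosen proportionally more often. The names here (pickWeighted, quotas) are illustrative and not part of DataScheduler.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

// Illustrative sketch of the cumulative-sum selection used above.
int pickWeighted(const std::vector<int>& quotas, std::mt19937& gen) {
  std::vector<int> cum(quotas.size());
  std::partial_sum(quotas.begin(), quotas.end(), cum.begin());
  // Draw in [1, total] and find the first cumulative sum that covers it.
  std::uniform_int_distribution<int> dist(1, cum.back());
  int d = dist(gen);
  auto it = std::lower_bound(cum.begin(), cum.end(), d);
  return static_cast<int>(it - cum.begin());
}

int main() {
  std::mt19937 gen(42);
  std::vector<int> quotas = {3, 1, 0, 2}; // dataset 2 can never be picked
  std::vector<int> counts(quotas.size(), 0);
  for (int i = 0; i < 6000; ++i) {
    ++counts[pickWeighted(quotas, gen)];
  }
  // Expect picks roughly proportional to 3 : 1 : 0 : 2.
  for (size_t i = 0; i < counts.size(); ++i) {
    std::cout << "dataset " << i << ": " << counts[i] << " picks\n";
  }
  return 0;
}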