in recipes/local_prior_match/src/runtime/DataScheduler.cpp [50:85]
// Validates the configured training schedule and picks the first dataset
// to draw from, according to FLAGS_schedulerorder:
//   - kInOrder:      start with the first dataset that has a nonzero iteration count
//   - kUniformOrder: start with the dataset having the most scheduled iterations
//   - kRandomOrder:  start with a dataset sampled proportionally to its share
//                    of the total scheduled iterations
// Fatally logs (and aborts) on any invalid schedule.
void DataScheduler::initialize() {
  LOG_IF(FATAL, ds_.size() != dsNumIters_.size())
      << "mismatch between the number of datasets "
      << "and the number of schedules specified";
  // An empty schedule would make dsCumNumIters_.back() below undefined
  // behavior, so reject it explicitly.
  LOG_IF(FATAL, dsNumIters_.empty())
      << "Invalid training schedule (no datasets specified)";
  // Build the cumulative (prefix-sum) iteration counts; dsCumNumIters_[i]
  // is the total number of iterations scheduled across datasets 0..i.
  dsCumNumIters_.resize(dsNumIters_.size());
  for (size_t i = 0; i < dsNumIters_.size(); ++i) {
    LOG_IF(FATAL, dsNumIters_[i] < 0)
        << "Invalid training schedule (number of iterations < 0)";
    if (i == 0) {
      dsCumNumIters_[i] = dsNumIters_[i];
    } else {
      dsCumNumIters_[i] = dsNumIters_[i] + dsCumNumIters_[i - 1];
    }
  }
  LOG_IF(FATAL, dsCumNumIters_.back() == 0)
      << "Invalid training schedule (zero iterations on all datasets)";
  if (FLAGS_schedulerorder == kInOrder) {
    // Skip over any leading datasets scheduled for zero iterations.
    curDs_ = 0;
    while (curDs_ < dsNumIters_.size() && dsNumIters_[curDs_] == 0) {
      ++curDs_;
    }
  } else if (FLAGS_schedulerorder == kUniformOrder) {
    // Begin with the dataset that has the largest iteration budget.
    curDs_ = std::max_element(dsNumIters_.begin(), dsNumIters_.end()) -
        dsNumIters_.begin();
  } else if (FLAGS_schedulerorder == kRandomOrder) {
    // Sample a dataset with probability proportional to its iteration
    // count: draw d uniformly from [1, total] and find the first dataset
    // whose cumulative count reaches d.
    std::uniform_int_distribution<int> distribution(1, dsCumNumIters_.back());
    auto d = distribution(gen_);
    auto lit =
        std::lower_bound(dsCumNumIters_.begin(), dsCumNumIters_.end(), d);
    curDs_ = lit - dsCumNumIters_.begin();
  } else {
    LOG(FATAL) << "unimplemented order: " << FLAGS_schedulerorder;
  }
}