void DataScheduler::initialize()

in recipes/local_prior_match/src/runtime/DataScheduler.cpp [50:85]


// Validates the per-dataset iteration schedule, builds the cumulative
// iteration table and selects the dataset to start from according to
// FLAGS_schedulerorder.
//
// Fatal (LOG(FATAL)) if:
//   - ds_ and dsNumIters_ differ in length;
//   - no schedule entries are given at all;
//   - any entry of dsNumIters_ is negative;
//   - every entry of dsNumIters_ is zero.
//
// Afterwards:
//   - dsCumNumIters_[i] == dsNumIters_[0] + ... + dsNumIters_[i];
//   - curDs_ is chosen as follows:
//       kInOrder      -> first dataset with a non-zero iteration count;
//       kUniformOrder -> dataset with the largest iteration count (the first
//                        one on ties, per std::max_element);
//       kRandomOrder  -> a dataset drawn from gen_ with probability
//                        proportional to its iteration count, so datasets
//                        with 0 iterations are never picked.
void DataScheduler::initialize() {
  LOG_IF(FATAL, ds_.size() != dsNumIters_.size())
      << "mismatch between the number of datasets "
      << "and the number of schedules specified";
  // Guard against an empty schedule: dsCumNumIters_.back() below would
  // otherwise be undefined behavior on an empty vector.
  LOG_IF(FATAL, dsNumIters_.empty())
      << "Invalid training schedule (no datasets specified)";

  // Build the running (prefix) sum of iterations per dataset.
  dsCumNumIters_.resize(dsNumIters_.size());
  for (size_t i = 0; i < dsNumIters_.size(); ++i) {
    LOG_IF(FATAL, dsNumIters_[i] < 0)
        << "Invalid training schedule (number of iterations < 0)";
    dsCumNumIters_[i] =
        (i == 0) ? dsNumIters_[i] : dsNumIters_[i] + dsCumNumIters_[i - 1];
  }
  LOG_IF(FATAL, dsCumNumIters_.back() == 0)
      << "Invalid training schedule (zero iterations on all datasets)";

  if (FLAGS_schedulerorder == kInOrder) {
    // Skip leading datasets that are scheduled for zero iterations. The
    // non-zero total checked above guarantees this stops on a valid index.
    curDs_ = 0;
    while (curDs_ < dsNumIters_.size() && dsNumIters_[curDs_] == 0) {
      ++curDs_;
    }
  } else if (FLAGS_schedulerorder == kUniformOrder) {
    curDs_ = std::max_element(dsNumIters_.begin(), dsNumIters_.end()) -
        dsNumIters_.begin();
  } else if (FLAGS_schedulerorder == kRandomOrder) {
    // Draw d uniformly from [1, total]. The first index whose cumulative
    // count is >= d owns d. That index i owns exactly dsNumIters_[i] of the
    // possible values, so zero-iteration datasets can never be selected.
    std::uniform_int_distribution<int> distribution(1, dsCumNumIters_.back());
    auto d = distribution(gen_);
    auto lit =
        std::lower_bound(dsCumNumIters_.begin(), dsCumNumIters_.end(), d);
    curDs_ = lit - dsCumNumIters_.begin();
  } else {
    LOG(FATAL) << "unimplemented order: " << FLAGS_schedulerorder;
  }
}