void DataScheduler::update()

in recipes/local_prior_match/src/runtime/DataScheduler.cpp [98:141]


/**
 * Advances the scheduler by one iteration and selects the dataset to draw
 * from next.
 *
 * Steps:
 *  1. Counts one more iteration against the current dataset (`curDs_`).
 *  2. Unless `FLAGS_noresample` is set, reshuffles that dataset whenever
 *     `dsIterOffset_[curDs_] + dsCurIter_[curDs_]` is a multiple of its
 *     `size()`. The shuffle seed is the dataset's incremented epoch counter.
 *  3. Updates `curDs_` according to `FLAGS_schedulerorder`:
 *     - kInOrder: stays on the current dataset for `dsNumIters_[curDs_]`
 *       consecutive iterations, then moves cyclically to the next dataset
 *       whose budget is non-zero.
 *     - kUniformOrder: picks the dataset with the least relative progress
 *       (completed budgets plus fractional position within the current one).
 *       Ties go to the lowest index.
 *     - kRandomOrder: samples a dataset with probability proportional to the
 *       iterations it still has left in the current round, i.e. sampling
 *       without replacement. Budgets are refilled from `dsNumIters_` once
 *       every one of them is used up.
 *     Any other value leaves `curDs_` unchanged.
 *
 * NOTE(review): assumes at least one dataset has a non-zero `dsNumIters_`.
 * It also assumes `dsCumNumIters_` was initialized (presumably in the
 * constructor) to the prefix sums of `dsNumIters_`. Neither is checked here.
 */
void DataScheduler::update() {
  ++dsCurIter_[curDs_];

  if (!FLAGS_noresample) {
    // Guard against modulo-by-zero when the current dataset is empty.
    const auto dsSize = ds_[curDs_]->size();
    if (dsSize > 0 &&
        (dsIterOffset_[curDs_] + dsCurIter_[curDs_]) % dsSize == 0) {
      LOG_MASTER(INFO) << "Shuffling trainset";
      // The incremented epoch counter doubles as the shuffle seed.
      ds_[curDs_]->shuffle(++dsCurEpochs_[curDs_] /* seed */);
    }
  }

  if (FLAGS_schedulerorder == kInOrder) {
    // Once the current dataset has used up its consecutive-iteration budget,
    // advance cyclically and skip datasets whose budget is zero.
    if (dsCurIter_[curDs_] % dsNumIters_[curDs_] == 0) {
      do {
        curDs_ = (curDs_ + 1) % ds_.size();
      } while (dsNumIters_[curDs_] == 0);
    }
  } else if (FLAGS_schedulerorder == kUniformOrder) {
    double minProgress = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < ds_.size(); ++i) {
      if (dsNumIters_[i] > 0) {
        // Integer part: number of fully consumed budgets. Kept in the
        // counter's own type so large counts are not truncated to int.
        const auto completed = dsCurIter_[i] / dsNumIters_[i];
        // Fractional part: position inside the current budget, scaled into
        // (0, 1) so it never reaches the next whole number.
        const double fraction =
            1.0 / (dsNumIters_[i] + 1) * (dsCurIter_[i] % dsNumIters_[i] + 1);
        const double progress = completed + fraction;
        if (progress < minProgress) {
          minProgress = progress;
          curDs_ = static_cast<decltype(curDs_)>(i);
        }
      }
    }
  } else if (FLAGS_schedulerorder == kRandomOrder) {
    // Consume one unit of the current dataset's remaining budget. Every
    // cumulative count from curDs_ onward shrinks by one.
    for (auto c = static_cast<std::size_t>(curDs_);
         c < dsCumNumIters_.size();
         ++c) {
      --dsCumNumIters_[c];
    }
    // All budgets exhausted: start a new round with full budgets.
    if (dsCumNumIters_.back() == 0) {
      std::partial_sum(
          dsNumIters_.begin(), dsNumIters_.end(), dsCumNumIters_.begin());
    }
    // Draw a slot uniformly among all remaining iterations and map it to the
    // first dataset whose cumulative count reaches it. A dataset with no
    // remaining budget spans no slots, so it can never be chosen.
    std::uniform_int_distribution<int> distribution(1, dsCumNumIters_.back());
    const auto slot = distribution(gen_);
    const auto owner =
        std::lower_bound(dsCumNumIters_.begin(), dsCumNumIters_.end(), slot);
    curDs_ = owner - dsCumNumIters_.begin();
  }
}