void CPUWeightedSampler::WeightedSample()

in graphlearn_torch/csrc/cpu/weighted_sampler.cc [167:191]


// Samples neighbors of a single node according to per-edge weights.
//
// [col_begin, col_end) holds the neighbor ids, [eid_begin, eid_end) the
// matching edge ids and [prob_begin, prob_end) the matching sampling weights.
//
// - If the node has more neighbors than requested (req_num < cap), req_num
//   indices are drawn independently from a std::discrete_distribution over the
//   weights, i.e. WITH replacement: the same neighbor may appear repeatedly.
//   out_nbrs[i] / out_eid[i] receive the neighbor and edge id of draw i.
// - Otherwise every neighbor id and edge id is copied through unchanged and in
//   order, so cap entries are written.
//
// The caller must size out_nbrs / out_eid for whichever branch applies.
//
// NOTE(review): std::discrete_distribution requires non-negative weights whose
// sum is positive; an all-zero or negative weight range violates its
// precondition. Also assumes the prob range has the same length as the col
// range — confirm at the call site.
void CPUWeightedSampler::WeightedSample(const int64_t* col_begin,
                                        const int64_t* col_end,
                                        const int64_t* eid_begin,
                                        const int64_t* eid_end,
                                        const int32_t req_num,
                                        const float* prob_begin,
                                        const float* prob_end,
                                        int64_t* out_nbrs,
                                        int64_t* out_eid) {
  const auto cap = col_end - col_begin;
  if (req_num < cap) {
    // One engine per thread, seeded the first time that thread reaches this
    // line. The seed is only read during that initialization, so query the
    // seed manager inside the initializer rather than on every call.
    // NOTE(review): each thread seeds from the same manager; if getSeed()
    // returns a fixed value, all threads replay identical random streams.
    thread_local static std::mt19937 engine(
        static_cast<uint32_t>(RandomSeedManager::getInstance().getSeed()));
    std::discrete_distribution<> dist(prob_begin, prob_end);
    for (int32_t i = 0; i < req_num; ++i) {
      const auto idx = dist(engine);
      out_nbrs[i] = col_begin[idx];
      out_eid[i] = eid_begin[idx];
    }
  } else {
    // Not more neighbors than requested: take all of them as-is.
    std::copy(col_begin, col_end, out_nbrs);
    std::copy(eid_begin, eid_end, out_eid);
  }
}