in graphlearn_torch/csrc/cpu/random_negative_sampler.cc [24:74]
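// Samples up to `req_num` negative (non-existent) edges from the graph's CSR
// structure. Each request slot gets at most `trials_num` rejection-sampling
// attempts; when `padding` is true, any remaining slots are filled with
// unchecked random pairs so the output always holds `req_num` entries.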
std::tuple<torch::Tensor, torch::Tensor> CPURandomNegativeSampler::Sample(
    int32_t req_num, int32_t trials_num, bool padding) {
  const int64_t* row_ptr = graph_->GetRowPtr();
  const int64_t* col_idx = graph_->GetColIdx();
  int64_t row_num = graph_->GetRowCount();
  int64_t col_num = graph_->GetColCount();
  uint32_t seed = RandomSeedManager::getInstance().getSeed();
  // Each worker thread gets its own engine, seeded from the same global seed,
  // so runs are reproducible but threads start from identical random streams.
  thread_local static std::mt19937 engine(seed);
  std::uniform_int_distribution<int64_t> row_dist(0, row_num - 1);
  std::uniform_int_distribution<int64_t> col_dist(0, col_num - 1);
  // Heap-allocated buffers instead of variable-length arrays, which are a
  // compiler extension and can overflow the stack for large `req_num`.
  std::vector<int64_t> row_data(req_num);
  std::vector<int64_t> col_data(req_num);
  std::vector<int32_t> out_prefix(req_num, 0);
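  // Rejection sampling, parallelized over request slots: draw random
  // (row, col) pairs and keep the first one that is not an existing edge
  // in the CSR structure; give up on a slot after `trials_num` misses.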
  at::parallel_for(0, req_num, 1, [&](int32_t start, int32_t end) {
    for (int32_t i = start; i < end; ++i) {
      for (int32_t j = 0; j < trials_num; ++j) {
        int64_t r = row_dist(engine);
        int64_t c = col_dist(engine);
        if (!EdgeInCSR(row_ptr, col_idx, r, c)) {
          row_data[i] = r;
          col_data[i] = c;
          out_prefix[i] = 1;
          break;
        }
      }
    }
  });
  // Compact successful samples to the front, preserving slot order.
  int32_t cursor = 0;
  for (int32_t i = 0; i < req_num; ++i) {
    if (out_prefix[i] == 1) {
      row_data[cursor] = row_data[i];
      col_data[cursor] = col_data[i];
      ++cursor;
    }
  }
  int32_t sampled_num = cursor;  // number of verified negative edges.
  // Non-strict padding: fill the remaining slots with unchecked random
  // pairs, which may include false negatives (existing edges).
  while ((sampled_num < req_num) && padding) {
    row_data[sampled_num] = row_dist(engine);
    col_data[sampled_num] = col_dist(engine);
    ++sampled_num;
  }
  torch::Tensor rows = torch::empty({sampled_num}, torch::kInt64);
  torch::Tensor cols = torch::empty({sampled_num}, torch::kInt64);
  std::copy(row_data.begin(), row_data.begin() + sampled_num,
            rows.data_ptr<int64_t>());
  std::copy(col_data.begin(), col_data.begin() + sampled_num,
            cols.data_ptr<int64_t>());
  return std::make_tuple(rows, cols);
}
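
// A minimal usage sketch (the constructor wiring below is an assumption; the
// actual Graph setup lives elsewhere in graphlearn_torch):
//
//   CPURandomNegativeSampler sampler(graph);  // graph: Graph* (assumed ctor)
//   torch::Tensor rows, cols;
//   std::tie(rows, cols) = sampler.Sample(/*req_num=*/128,
//                                         /*trials_num=*/5,
//                                         /*padding=*/true);
//   // With padding=true, rows/cols each hold exactly 128 int64 entries;
//   // without padding they may be shorter if some slots exhausted their trials.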