torch::Tensor randperm()

in torchbiggraph/util.cpp [14:102]


torch::Tensor randperm(long numItems, int numThreads, int64_t seedIn = -1) {
  // Parallel random permutation of [0, numItems):
  //   1) each thread assigns its slice of items to a uniformly random chunk,
  //   2) items are scattered into perm so that each chunk is contiguous,
  //   3) each chunk is Fisher-Yates shuffled in place by its own thread.
  // If seedIn >= 0 the result is deterministic for a fixed numThreads;
  // otherwise every generator is seeded with at::default_rng_seed_val.
  //
  // Workaround a breaking change in the name of CPUGenerator in PyTorch 1.5
  // https://github.com/pytorch/pytorch/pull/36027
  // This code will pick whichever class exists
  typedef std::conditional< // NOLINT
      std::is_constructible<at::CPUGeneratorImpl, uint64_t>::value,
      at::CPUGeneratorImpl,
      at::CPUGenerator>::type CPUGeneratorType;

  assert(numThreads > 0);
  // Chunk ids are stored in a uint8 tensor, so at most 255 chunks/threads.
  assert(numThreads < 256);

  auto perm = torch::empty(numItems, torch::kInt64);
  auto permAccessor = perm.accessor<int64_t, 1>();
  torch::Tensor chunks = torch::empty({numItems}, torch::kUInt8);
  auto chunksAccessor = chunks.accessor<uint8_t, 1>();

  // allCounts[t][c] = number of items thread t assigned to chunk c.
  // Counts and offsets are int64_t (not int) so they cannot overflow when
  // numItems exceeds 2^31.
  std::vector<std::vector<int64_t>> allCounts(numThreads);
  auto stepOne = [&](int64_t startIdx, int64_t endIdx, int threadIdx) {
    CPUGeneratorType generator(
        seedIn >= 0 ? seedIn + threadIdx : at::default_rng_seed_val);

    std::vector<int64_t>& myCounts = allCounts[threadIdx];
    myCounts.assign(numThreads, 0);
    // int64_t index: endIdx may exceed INT_MAX for large numItems.
    for (int64_t idx = startIdx; idx < endIdx; idx += 1) {
      // NOTE(review): modulo has a negligible bias when numThreads does not
      // evenly divide the generator's 32-bit range.
      chunksAccessor[idx] = generator.random() % numThreads;
      myCounts[chunksAccessor[idx]] += 1;
    }
  };
  std::vector<std::thread> stepOneThreads;
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepOneThreads.emplace_back(
        stepOne,
        threadIdx * numItems / numThreads,
        (threadIdx + 1) * numItems / numThreads,
        threadIdx);
  }
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepOneThreads[threadIdx].join();
  }

  // allOffsets[t][c] = position in perm at which thread t writes its next
  // item belonging to chunk c. Chunks are laid out contiguously in perm and,
  // within a chunk, the threads' regions appear in thread order.
  std::vector<std::vector<int64_t>> allOffsets(numThreads);
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    allOffsets[threadIdx].reserve(numThreads);
  }
  int64_t offset = 0;
  for (int chunkIdx = 0; chunkIdx < numThreads; chunkIdx += 1) {
    for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
      allOffsets[threadIdx].push_back(offset);
      offset += allCounts[threadIdx][chunkIdx];
    }
  }
  assert(offset == numItems);

  // Step two: scatter each item's index into its chunk's region of perm,
  // advancing this thread's per-chunk write cursor as it goes. Threads write
  // to disjoint regions, so no synchronization is needed.
  auto stepTwo = [&](int64_t startIdx, int64_t endIdx, int threadIdx) {
    std::vector<int64_t>& myOffsets = allOffsets[threadIdx];
    for (int64_t idx = startIdx; idx < endIdx; idx += 1) {
      int64_t& offset = myOffsets[chunksAccessor[idx]];
      permAccessor[offset] = idx;
      offset += 1;
    }
  };
  std::vector<std::thread> stepTwoThreads;
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepTwoThreads.emplace_back(
        stepTwo,
        threadIdx * numItems / numThreads,
        (threadIdx + 1) * numItems / numThreads,
        threadIdx);
  }
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepTwoThreads[threadIdx].join();
  }

  // Step three: thread c Fisher-Yates shuffles chunk c in place. Seeds are
  // offset by numThreads so they don't collide with step one's generators.
  auto stepThree = [&](int64_t startIdx, int64_t endIdx, int threadIdx) {
    CPUGeneratorType generator(
        seedIn >= 0 ? seedIn + threadIdx + numThreads
                    : at::default_rng_seed_val);
    for (int64_t idx = startIdx; idx < endIdx - 1; idx += 1) {
      // NOTE(review): random() yields 32 bits, so a chunk longer than 2^32
      // items would get a biased shuffle — presumably acceptable here.
      int64_t otherIdx = idx + generator.random() % (endIdx - idx);
      std::swap(permAccessor[idx], permAccessor[otherIdx]);
    }
  };
  std::vector<std::thread> stepThreeThreads;
  // After step two, allOffsets[numThreads - 1][c] is the end of chunk c:
  // the last thread's write cursor for chunk c has advanced to its end.
  offset = 0;
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepThreeThreads.emplace_back(
        stepThree, offset, allOffsets[numThreads - 1][threadIdx], threadIdx);
    offset = allOffsets[numThreads - 1][threadIdx];
  }
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepThreeThreads[threadIdx].join();
  }
  return perm;
}