void shuffle()

in torchbiggraph/util.cpp [127:220]
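
Permutes the rows of tensor in place so that the row originally at index i ends up at index permutation[i]. The work is split across numThreads threads, and the permutation is applied by following its cycles, so each thread only needs two row-sized scratch buffers rather than a full copy of the tensor; one atomic flag per row ensures every cycle is applied exactly once.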


void shuffle(
    at::Tensor& tensor,
    const at::Tensor& permutation,
    int numThreads) {
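  // Validate dtypes and shapes up front; the threaded phases below rely on
  // these invariants.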
  if (permutation.scalar_type() != c10::ScalarType::Long) {
    throw std::invalid_argument("Permutation must have int64 dtype");
  }
  if (permutation.dim() != 1) {
    throw std::invalid_argument("Permutation must have exactly one dimension");
  }
  if (tensor.dim() < 1) {
    throw std::invalid_argument("Tensor must have at least one dimension");
  }
  int64_t numRows = tensor.sizes()[0];
  if (numRows != permutation.sizes()[0]) {
    throw std::invalid_argument(
        "Tensor and permutation must have the same number of elements on the first dimension");
  }
  if (numRows == 0) {
    return;
  }
  int64_t rowStride = tensor.strides()[0] * tensor.element_size();
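  // A stride of zero means every row aliases the same storage, so any
  // permutation leaves the tensor unchanged.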
  if (rowStride == 0) {
    return;
  }
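  // Rows are moved with raw memcpy below, so each row must occupy one
  // contiguous block of memory.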
  if (!tensor[0].is_contiguous()) {
    throw std::invalid_argument(
        "Each sub-tensor of tensor (along the first dimension) must be contiguous");
  }
  for (int i = 1; i < tensor.dim(); i += 1) {
    if (tensor.strides()[i] == 0) {
      throw std::invalid_argument(
          "Tensor cannot have strides that are zero (for now)");
    }
  }
  int64_t rowSize = tensor[0].nbytes();

  // The pointee type doesn't matter, as long as it has size 1, so that the
  // offsets computed below are in bytes.
  uint8_t* tensorData = reinterpret_cast<uint8_t*>(tensor.data_ptr());
  int64_t* permutationData = permutation.data_ptr<int64_t>();

  // One atomic flag per row. A set flag means the row has been claimed: its
  // original contents have been copied into a scratch buffer and its final
  // contents will be written exactly once by whichever thread walks its cycle.
  std::vector<std::atomic_flag> checks(numRows);
  std::atomic_flag* checksData = checks.data();

  // Step one: clear every flag. std::atomic_flag is not guaranteed to start
  // out cleared before C++20, so the flags are reset explicitly, in parallel.
  auto stepOne = [&](int64_t startIdx, int64_t endIdx) {
    for (int64_t idx = startIdx; idx < endIdx; idx += 1) {
      checksData[idx].clear();
    }
  };
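  // Split [0, numRows) into numThreads contiguous ranges, one per thread.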
  std::vector<std::thread> stepOneThreads;
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepOneThreads.emplace_back(
        stepOne,
        threadIdx * numRows / numThreads,
        (threadIdx + 1) * numRows / numThreads);
  }
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepOneThreads[threadIdx].join();
  }
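  // Step two: apply the permutation by following its cycles. Each thread
  // starts cycles from the indices in its own range; the flags guarantee that
  // every cycle is processed exactly once, even when it spans several ranges.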
  auto stepTwo = [&](int64_t startIdx, int64_t endIdx) {
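    // Two row-sized scratch buffers: one carries the row currently being
    // relocated, the other receives the row it is about to displace.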
    std::vector<uint8_t> bufferOne(rowSize);
    std::vector<uint8_t> bufferTwo(rowSize);
    void* bufferOneData = bufferOne.data();
    void* bufferTwoData = bufferTwo.data();
    for (int64_t baseIdx = startIdx; baseIdx < endIdx; baseIdx += 1) {
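      // Copy the row's original contents before trying to claim it; if the
      // claim fails because this row was already handled as part of an
      // earlier cycle (by this or another thread), the copy is discarded.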
      int64_t curIdx = baseIdx;
      std::memcpy(bufferOneData, tensorData + curIdx * rowStride, rowSize);
      if (checksData[curIdx].test_and_set()) {
        continue;
      }
      bool done = false;
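      // Walk the cycle: bufferOne always carries the original contents of the
      // previous index, which belong at the next curIdx.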
      while (!done) {
        curIdx = permutationData[curIdx];
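        // Bounds check; note that an exception escaping a std::thread body
        // calls std::terminate, so an out-of-range permutation is fatal here.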
        if (curIdx < 0 || curIdx >= numRows) {
          throw std::invalid_argument("Permutation has out-of-bound values");
        }
        std::memcpy(bufferTwoData, tensorData + curIdx * rowStride, rowSize);
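        // Claim this row; if it was already claimed (we are back at the start
        // of the cycle, or another walker started here), this write completes
        // our part of the cycle. Either way, drop the carried row into place,
        // then swap buffers so the displaced row becomes the carried row.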
        done = checksData[curIdx].test_and_set();
        std::memcpy(tensorData + curIdx * rowStride, bufferOneData, rowSize);
        std::swap(bufferOneData, bufferTwoData);
      }
    }
  };
  std::vector<std::thread> stepTwoThreads;
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepTwoThreads.emplace_back(
        stepTwo,
        threadIdx * numRows / numThreads,
        (threadIdx + 1) * numRows / numThreads);
  }
  for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
    stepTwoThreads[threadIdx].join();
  }
}
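
A minimal caller might look like the sketch below. This is an illustration only, not part of util.cpp: it assumes the LibTorch headers are available and supplies its own forward declaration of shuffle in place of whatever header the project actually provides. After the call, the row that was originally at index i sits at index permutation[i].

#include <iostream>
#include <torch/torch.h>

// Forward declaration matching the definition above (normally supplied by the
// project's own header; repeated here only to keep the sketch self-contained).
void shuffle(at::Tensor& tensor, const at::Tensor& permutation, int numThreads);

int main() {
  // A small 4x3 float tensor whose rows are easy to tell apart.
  at::Tensor tensor = torch::arange(12, torch::kFloat32).reshape({4, 3});
  // A random permutation of [0, 4) with the required int64 dtype.
  at::Tensor permutation = torch::randperm(4, torch::kLong);

  shuffle(tensor, permutation, /*numThreads=*/2);

  // Row i of the original tensor is now at index permutation[i].
  std::cout << permutation << std::endl << tensor << std::endl;
  return 0;
}

The cycle-following, in-place approach keeps the extra memory at a couple of row-sized buffers per thread instead of a full second copy of the tensor.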