in torchbiggraph/util.cpp [127:220]
// Applies `permutation` to the rows of `tensor`, in place, using
// `numThreads` worker threads: after the call, the data that used to be in
// row i is found in row permutation[i] (newTensor[permutation[i]] =
// oldTensor[i]).
//
// Requirements (std::invalid_argument on violation):
// - numThreads >= 1;
// - permutation is a one-dimensional int64 tensor, of the same length as
//   tensor's first dimension, with all values in [0, numRows);
// - tensor has at least one dimension, each of its sub-tensors along the
//   first dimension is contiguous, and no stride past the first is zero.
//
// The algorithm decomposes the permutation into cycles and rotates each
// cycle's rows through two small bounce buffers; an atomic flag per row
// arbitrates which thread processes which cycle.
void shuffle(
    at::Tensor& tensor,
    const at::Tensor& permutation,
    int numThreads) {
  if (numThreads < 1) {
    // Previously numThreads <= 0 silently spawned no threads and left the
    // tensor unshuffled; make that misuse loud instead.
    throw std::invalid_argument("numThreads must be at least 1");
  }
  if (permutation.scalar_type() != c10::ScalarType::Long) {
    throw std::invalid_argument("Permutation must have int64 dtype");
  }
  if (permutation.dim() != 1) {
    throw std::invalid_argument("Permutation must have exactly one dimension");
  }
  if (tensor.dim() < 1) {
    throw std::invalid_argument("Tensor must have at least one dimension");
  }
  int64_t numRows = tensor.sizes()[0];
  if (numRows != permutation.sizes()[0]) {
    throw std::invalid_argument(
        "Tensor and permutation must have the same number of elements on the first dimension");
  }
  if (numRows == 0) {
    return;
  }
  // Distance, in bytes, between the starts of two consecutive rows.
  int64_t rowStride = tensor.strides()[0] * tensor.element_size();
  if (rowStride == 0) {
    // All rows alias the same storage, so any permutation is a no-op.
    return;
  }
  if (!tensor[0].is_contiguous()) {
    throw std::invalid_argument(
        "Each sub-tensor of tensor (along the first dimension) must be contiguous");
  }
  for (int i = 1; i < tensor.dim(); i += 1) {
    if (tensor.strides()[i] == 0) {
      throw std::invalid_argument(
          "Tensor cannot have strides that are zero (for now)");
    }
  }
  int64_t rowSize = tensor[0].nbytes();
  // This pointer's type doesn't matter, as long as it has size 1.
  uint8_t* tensorData = reinterpret_cast<uint8_t*>(tensor.data_ptr());
  int64_t* permutationData = permutation.data_ptr<int64_t>();
  std::vector<std::atomic_flag> checks(numRows);
  std::atomic_flag* checksData = checks.data();

  // Splits [0, numRows) into numThreads contiguous slices, runs fn on each
  // slice in its own thread, and joins them all.
  auto runInParallel = [&](auto&& fn) {
    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    for (int threadIdx = 0; threadIdx < numThreads; threadIdx += 1) {
      threads.emplace_back(
          fn,
          threadIdx * numRows / numThreads,
          (threadIdx + 1) * numRows / numThreads);
    }
    for (std::thread& thread : threads) {
      thread.join();
    }
  };

  // Step one: clear the per-row flags and, in the same pass, validate the
  // permutation's values. FIX: the old code checked bounds inside the
  // shuffling threads and threw std::invalid_argument there; an exception
  // escaping a thread's entry function calls std::terminate, so a bad
  // permutation crashed the process. We instead count bad entries with an
  // atomic and throw from the calling thread, before any row is moved.
  std::atomic<int64_t> numOutOfBounds(0);
  auto stepOne = [&](int64_t startIdx, int64_t endIdx) {
    int64_t localBad = 0;
    for (int64_t idx = startIdx; idx < endIdx; idx += 1) {
      checksData[idx].clear();
      int64_t target = permutationData[idx];
      if (target < 0 || target >= numRows) {
        localBad += 1;
      }
    }
    if (localBad != 0) {
      numOutOfBounds.fetch_add(localBad, std::memory_order_relaxed);
    }
  };
  runInParallel(stepOne);
  if (numOutOfBounds.load() != 0) {
    throw std::invalid_argument("Permutation has out-of-bound values");
  }

  // Step two: follow the permutation's cycles. Each thread scans its slice
  // for unclaimed rows and, on claiming one, rotates that entire cycle.
  //
  // NOTE(ordering): every read of a row deliberately happens *before* the
  // test_and_set on that row's flag, and every write happens *after* one.
  // The flags' (seq_cst) read-modify-writes are thus what order two threads
  // that touch the same row; do not "optimize" the reads to after the
  // claim check.
  auto stepTwo = [&](int64_t startIdx, int64_t endIdx) {
    std::vector<uint8_t> bufferOne(rowSize);
    std::vector<uint8_t> bufferTwo(rowSize);
    void* bufferOneData = bufferOne.data();
    void* bufferTwoData = bufferTwo.data();
    for (int64_t baseIdx = startIdx; baseIdx < endIdx; baseIdx += 1) {
      int64_t curIdx = baseIdx;
      // Read the candidate cycle start (possibly wasted if already claimed;
      // see the ordering note above for why the read comes first).
      std::memcpy(bufferOneData, tensorData + curIdx * rowStride, rowSize);
      if (checksData[curIdx].test_and_set()) {
        // Some thread already owns the cycle through this row.
        continue;
      }
      bool done = false;
      while (!done) {
        // In-bounds guaranteed by the validation in step one.
        curIdx = permutationData[curIdx];
        // Save the destination row's old contents, claim it, then deposit
        // the carried row; stop once we hit an already-claimed row (the
        // cycle start, closing the loop).
        std::memcpy(bufferTwoData, tensorData + curIdx * rowStride, rowSize);
        done = checksData[curIdx].test_and_set();
        std::memcpy(tensorData + curIdx * rowStride, bufferOneData, rowSize);
        std::swap(bufferOneData, bufferTwoData);
      }
    }
  };
  runInParallel(stepTwo);
}