// include/cuda_utils.cuh
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <vector>
#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
#else
constexpr int WARP_SIZE = 32;
#endif
// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}
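// For example, WARP_SHFL_XOR(val, 1) exchanges `val` between adjacent lanes
// (lane 0 <-> lane 1, lane 2 <-> lane 3, ...); larger lane masks give the
// butterfly exchange pattern that warpSum() below builds on.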
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}
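// For example, getNumThreads(10) returns 32 on CUDA and 16 on HIP, while any
// nElem above the largest entry falls through to MAX_BLOCK_SIZE.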
static int lastPow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return n - (n >> 1);
}
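// For example, lastPow2(300) == 256 and lastPow2(64) == 64: the shifts smear
// the highest set bit into every lower position, and subtracting n >> 1 then
// leaves only that highest bit.
//
// A minimal sketch (the helper name getLaunchDims is hypothetical, not part of
// the original header) of how getNumThreads and lastPow2 might be combined to
// pick 2D launch dimensions while respecting the WARP_SIZE**2 thread limit
// assumed by reduce() below:
static inline dim3 getLaunchDims(unsigned int batch_size, int x_size) {
  int threads_x = getNumThreads(x_size);
  int max_y = WARP_SIZE * WARP_SIZE / threads_x;
  int threads_y = lastPow2(batch_size);
  if (threads_y > max_y) {
    threads_y = max_y;
  }
  if (threads_y < 1) {
    threads_y = 1;
  }
  return dim3(threads_x, threads_y, 1);
}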
// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}
// Sum across all threads within a warp
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
  return val;
}
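// A minimal sketch (the helper name warpMean is hypothetical, not part of the
// original header): using warpSum to compute a warp-wide mean. It assumes all
// WARP_SIZE lanes of the calling warp are active and hold a valid value.
template <typename T>
__device__ __forceinline__ T warpMean(T val) {
  return warpSum(val) / static_cast<T>(WARP_SIZE);
}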
template<
    typename scalar_t,
    int64_t dim,
    template <typename U> class PtrTraits = at::DefaultPtrTraits,
    typename index_t = int64_t>
static at::PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(
    const c10::optional<at::Tensor>& t) {
  if (!t.has_value()) {
    const std::vector<index_t> zeros(dim);
    return at::PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
  }
  return t.value().packed_accessor<scalar_t, dim, PtrTraits, index_t>();
}
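// Typical usage (a sketch; `weight` and `my_kernel` are hypothetical names):
// an optional tensor such as an absent affine weight becomes a dummy accessor
// whose data pointer is nullptr, so the kernel can branch on it:
//
//   auto weight_acc = packed_accessor_or_dummy<scalar_t, 1>(weight);
//   my_kernel<scalar_t><<<blocks, threads, 0, stream>>>(input_acc, weight_acc);
//   // inside the kernel: if (weight.data() != nullptr) { ... }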
template<typename scalar_t>
struct Float2 {
  scalar_t v1, v2;
  __device__ Float2() {}
  __device__ Float2(scalar_t v1, scalar_t v2) : v1(v1), v2(v2) {}
  __device__ Float2(int v) : v1(static_cast<scalar_t>(v)), v2(static_cast<scalar_t>(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};
template <typename scalar_t>
static __device__ __forceinline__ Float2<scalar_t> warpSum(Float2<scalar_t> value) {
  value.v1 = warpSum(value.v1);
  value.v2 = warpSum(value.v2);
  return value;
}
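// A minimal sketch (the helper name warpSumSquares is hypothetical, not part
// of the original header): reducing a (sum, sum-of-squares) pair across a warp
// in one pass, e.g. as a building block for computing mean and variance.
template <typename scalar_t>
__device__ __forceinline__ Float2<scalar_t> warpSumSquares(scalar_t v) {
  Float2<scalar_t> pair(v, v * v);
  return warpSum(pair);
}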
// Sum across (batch, x/y/z) applying Op() pointwise.
// This works by first having each thread sum its own slice
// of the data. Then there is a two-stage shuffle reduction:
// first, each warp (of WARP_SIZE threads) uses warpSum to reduce its
// data to the "warp leader", which writes its value into shared memory.
// Then a single warp reads the remaining (at most WARP_SIZE) items
// and reduces them using another warpSum.
// The implicit assumption is that there are no more
// than WARP_SIZE**2 threads per block.
template<typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // First, the reduction each thread does on its own slice of the data.
  scalar_t sum = static_cast<scalar_t>(0);
  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }
  // First warpSum: go from one value per thread to one value per warp.
  sum = warpSum(sum);
  // Each warp writes its item into shared memory. There are at most
  // WARP_SIZE items left because there are at most WARP_SIZE**2 threads
  // in the block to begin with.
  __shared__ scalar_t shared[WARP_SIZE];
  __syncthreads();
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  if (tid % WARP_SIZE == 0) {
    shared[tid / WARP_SIZE] = sum;
  }
  if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) {
    // zero out the other entries in shared
    shared[tid] = (scalar_t)0;
  }
  __syncthreads();
  // Second warpSum: reduce the intermediate values from shared memory
  // to a single number. The very first thread writes it back to shared memory.
  if (tid / WARP_SIZE == 0) {
    sum = warpSum(shared[tid]);
    if (tid == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Every thread reads the result back, so the reduced value is effectively
  // broadcast to the whole block.
  return shared[0];
}
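// A minimal usage sketch (SumOp and planeSumKernel are hypothetical names, not
// part of the original header): summing each plane of a (batch, plane, x)
// tensor with reduce(), launching one block per plane. The launch must use at
// most WARP_SIZE * WARP_SIZE threads per block, e.g. chosen via getNumThreads.
template <typename scalar_t, typename PTA>
struct SumOp {
  PTA tensor;
  __device__ __forceinline__ scalar_t operator()(int batch, int plane, int x) const {
    return tensor[batch][plane][x];
  }
};

template <typename scalar_t, typename index_t>
__global__ void planeSumKernel(
    at::PackedTensorAccessor<scalar_t, 3, at::DefaultPtrTraits, index_t> input,
    at::PackedTensorAccessor<scalar_t, 1, at::DefaultPtrTraits, index_t> output) {
  int plane = blockIdx.x; // gridDim.x == number of planes
  SumOp<scalar_t, decltype(input)> op = {input};
  scalar_t total = reduce<scalar_t>(op, input, plane);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    output[plane] = total;
  }
}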