include/cuda_utils.cuh (137 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Lanes per warp: 64 when building for HIP (__HIP_PLATFORM_HCC__), 32 otherwise.
#if defined(__HIP_PLATFORM_HCC__)
constexpr int WARP_SIZE = 64;
#else
constexpr int WARP_SIZE = 32;
#endif

// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
constexpr int MAX_BLOCK_SIZE = 256;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif

// Warp-shuffle XOR ("butterfly") exchange: returns `value` as held by the lane
// whose index is (this lane ^ laneMask) within each group of `width` lanes.
// `mask` names the participating lanes and is only consumed by the CUDA 9+
// *_sync path.
//
// CUDA_VERSION is defined by <cuda.h>, which this header does not include, so
// the nvcc-provided __CUDACC_VER_MAJOR__ is checked too. Otherwise a CUDA 9+
// build that never pulled in <cuda.h> would fall back to the legacy mask-less
// __shfl_xor (deprecated since CUDA 9 and unavailable when compiling for
// sm_70+). HIP always uses the mask-less intrinsic.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(
    T value,
    int laneMask,
    int width = warpSize,
    unsigned int mask = 0xffffffff) {
#if !defined(__HIP_PLATFORM_HCC__) &&                    \
    ((defined(CUDA_VERSION) && CUDA_VERSION >= 9000) || \
     (defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9))
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  // HIP, or CUDA older than 9.0: only the mask-less intrinsic is available.
  return __shfl_xor(value, laneMask, width);
#endif
}

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE:
// the smallest entry of the table below that is >= nElem, or MAX_BLOCK_SIZE
// when nElem exceeds every entry. Host-only helper (no __device__ qualifier).
// Declared `static inline` so translation units that include this header
// without calling it do not trigger -Wunused-function.
static inline int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
  const int threadSizes[5] = {16, 32, 64, 128, MAX_BLOCK_SIZE};
#else
  const int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
#endif
  for (int size : threadSizes) {
    if (nElem <= size) {
      return size;
    }
  }
  return MAX_BLOCK_SIZE;
}

// Largest power of two <= n (0 when n == 0). The shift/or cascade copies the
// highest set bit into every lower bit position; subtracting (n >> 1) then
// leaves only that highest bit. The result is returned as int, so callers
// should keep n below 2^31.
static inline int lastPow2(unsigned int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return static_cast<int>(n - (n >> 1));
}

// Returns the index of the most significant 1 bit in `val`.
// Only meaningful for val != 0 (zero has no set bit); warpSum below calls it
// with the positive constant WARP_SIZE.
__device__ __forceinline__ int getMSB(int val) {
  constexpr int kTopBitIndex = 31;  // index of the highest bit of a 32-bit int
  return kTopBitIndex - __clz(val);
}

// Sum across all threads within a warp.
// XOR-shuffle (butterfly) reduction in getMSB(WARP_SIZE) == log2(WARP_SIZE)
// steps using lane offsets 1, 2, 4, ..., WARP_SIZE / 2; afterwards every lane
// holds the warp-wide total. Expects all WARP_SIZE lanes to be present, since
// WARP_SHFL_XOR is called with its default (full) mask.
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
  const int numSteps = getMSB(WARP_SIZE);
  for (int step = 0; step < numSteps; ++step) {
    const int laneOffset = 1 << step;
    val += WARP_SHFL_XOR(val, laneOffset, WARP_SIZE);
  }
  return val;
}

// Returns a packed accessor for `t`. When `t` is empty, returns a dummy
// accessor instead: null data pointer and `dim` sizes/strides that are all
// zero. Kernels must not read through the dummy accessor.
template <
    typename scalar_t,
    int64_t dim,
    template <typename U> class PtrTraits = at::DefaultPtrTraits,
    typename index_t = int64_t>
static at::PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>
packed_accessor_or_dummy(const c10::optional<at::Tensor>& t) {
  using Accessor = at::PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>;
  if (t.has_value()) {
    return t.value().packed_accessor<scalar_t, dim, PtrTraits, index_t>();
  }
  // The accessor constructor reads sizes/strides from these buffers.
  const std::vector<index_t> zeroExtents(dim);
  return Accessor(nullptr, zeroExtents.data(), zeroExtents.data());
}

// A pair of scalars that is accumulated and reduced as one value.
template <typename scalar_t>
struct Float2 {
  scalar_t v1, v2;

  // Deliberately empty: v1/v2 are left uninitialized. Presumably kept empty so
  // Float2 can live in a __shared__ array (e.g. reduce's buffer), which CUDA
  // only allows for types with an empty default constructor.
  __device__ Float2() {}

  __device__ Float2(scalar_t first, scalar_t second) : v1(first), v2(second) {}

  // Broadcasts one integer (e.g. 0) into both components.
  __device__ Float2(int fill)
      : v1(static_cast<scalar_t>(fill)), v2(static_cast<scalar_t>(fill)) {}

  __device__ Float2& operator+=(const Float2& other) {
    v1 += other.v1;
    v2 += other.v2;
    return *this;
  }
};

// Warp-wide sum of a Float2: each component is reduced independently with the
// scalar warpSum overload.
template <typename scalar_t>
static __device__ __forceinline__ Float2<scalar_t> warpSum(
    Float2<scalar_t> value) {
  const scalar_t firstTotal = warpSum(value.v1);
  const scalar_t secondTotal = warpSum(value.v2);
  return Float2<scalar_t>(firstTotal, secondTotal);
}

// Sum across (batch, x/y/z) applying Op() pointwise.
// This works by first having each thread sum its part of the data, followed
// by a two-level shuffle reduction. First each warp (of WARP_SIZE threads)
// uses warpSum to reduce its data to the "warp leader", which writes its value
// into shared memory. Then a single warp reads the remaining (at most
// WARP_SIZE) items and reduces them using another warpSum.
// The implicit assumption is that there are no more than WARP_SIZE**2 threads.
// Preconditions (not checked here):
//  * Every thread of the block must call reduce(): it contains __syncthreads().
//  * blockDim.x * blockDim.y is assumed to be a multiple of WARP_SIZE and at
//    most WARP_SIZE * WARP_SIZE (one shared slot per warp, full-mask shuffles).
//  * `tensor` is walked over size(0) and size(2), and `op(batch, plane, x)`
//    must yield something addable to scalar_t. Presumably a 3-D accessor with
//    `plane` indexing dim 1 — TODO confirm against callers.
// Returns the block-wide total to every calling thread.
template <typename scalar_t, typename Op, typename PTA>
__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
  // Phase 1: each thread accumulates its own strided share of (batch, x).
  scalar_t acc = static_cast<scalar_t>(0);
  for (int b = threadIdx.y; b < tensor.size(0); b += blockDim.y) {
    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
      acc += op(b, plane, x);
    }
  }

  // Phase 2: reduce within each warp; every lane now holds its warp's sum.
  acc = warpSum(acc);

  // One slot per warp. Sized WARP_SIZE, hence the WARP_SIZE**2 thread limit.
  __shared__ scalar_t shared[WARP_SIZE];
  // Make sure all threads finished reading shared[0] from any earlier call
  // before the slots are overwritten.
  __syncthreads();

  const int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int lane = tid % WARP_SIZE;
  const int warpId = tid / WARP_SIZE;
  if (lane == 0) {
    shared[warpId] = acc;  // the warp leader publishes its warp's sum
  }
  // Zero the slots no warp writes, so the second warpSum adds zeros there.
  // NOTE(review): this uses floor division. If the thread count is not a
  // multiple of WARP_SIZE, the partial last warp's slot is zeroed here in the
  // same phase its leader writes it.
  if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) {
    shared[tid] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // Phase 3: the first warp reduces the per-warp partials; thread 0 stores the
  // block total back into shared[0].
  if (warpId == 0) {
    acc = warpSum(shared[tid]);
    if (tid == 0) {
      shared[0] = acc;
    }
  }
  __syncthreads();

  // Phase 4: every thread picks up the broadcast total.
  return shared[0];
}