torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh (79 lines of code) (raw):

#pragma once #ifdef USE_CUDA #include <torchaudio/csrc/rnnt/gpu/math.cuh> namespace torchaudio { namespace rnnt { template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE> __global__ void ReduceMax2D( int dim, const DTYPE* inputs, // [N, dim] CAST_DTYPE* outputs) { __shared__ CAST_DTYPE shared[NUM_THREADS]; // each thread reduces one matrix row int offset = blockIdx.x * dim; // [n, 0] CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { CAST_DTYPE next = inputs[offset + d]; if (next > val) { val = next; } } shared[threadIdx.x] = val; __syncthreads(); for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { if (threadIdx.x < stride && threadIdx.x + stride < dim) { if (shared[threadIdx.x + stride] > shared[threadIdx.x]) { shared[threadIdx.x] = shared[threadIdx.x + stride]; val = shared[threadIdx.x]; } } __syncthreads(); } CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { shf = __shfl_down_sync(0xFFFFFFFF, val, stride); if (threadIdx.x < stride && threadIdx.x + stride < dim) { if (shf > val) { val = shf; } } } if (threadIdx.x == 0) { outputs[blockIdx.x] = val; } } template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE> __global__ void ReduceLogSumExpGivenMax2D( int dim, const DTYPE* inputs, // [N, dim] CAST_DTYPE* outputs) { // in: max -> out: logsum __shared__ CAST_DTYPE shared[NUM_THREADS]; CAST_DTYPE max = outputs[blockIdx.x]; CAST_DTYPE val = 0; int offset = blockIdx.x * dim; for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max); } shared[threadIdx.x] = val; __syncthreads(); for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { if (threadIdx.x < stride && threadIdx.x + stride < dim) { val = shared[threadIdx.x] + shared[threadIdx.x + stride]; shared[threadIdx.x] = val; } __syncthreads(); } CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { shf = __shfl_down_sync(0xFFFFFFFF, val, stride); if (threadIdx.x < stride && threadIdx.x + stride < dim) { val = val + shf; } } if (threadIdx.x == 0) { outputs[blockIdx.x] = max + std::log(val); } } } // namespace rnnt } // namespace torchaudio #endif // USE_CUDA