torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh (349 lines of code) (raw):

#pragma once

#ifdef USE_CUDA

#include <cassert>
#include <cstdint>

#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

// Gathers, for every lattice cell (bTgt, t, u), the two log-probabilities
// used by the RNN-T recursions:
//   skip = logits(b, t, u, blank)  - denominators(b, t, u)
//   emit = logits(b, t, u, tgt[u]) - denominators(b, t, u)   (only u < U - 1)
//
// Launch layout (as read by this kernel):
//   blockIdx.z                            -> target row bTgt
//   blockIdx.x * blockDim.x + threadIdx.x -> time step t
//   blockIdx.y                            -> label position u
// Threads with t >= T or u >= U for their row return immediately.
//
// Parameters:
//   maxSrcLen, maxTgtLen: padded T / U extents passed to Indexer3D.
//   numTargets: D, the innermost extent of `logits`
//     (indexed as logits[cell * D + k], cell = Indexer3D(b, t, u)).
//   blank: class index used for the skip transition.
//   targets: indexed with Indexer2D(maxU - 1)(bTgt, u).
//   srcLengths: indexed by bSrc = bTgt / H; gives T.
//   tgtLengths: indexed by bTgt; U = tgtLengths[bTgt] + 1.
//   denominators: one value per cell (presumably the log-normalizer over D,
//     computed by the caller -- confirm).
//   logProbs: output; two entries per cell at (cell << 1) + LOG_PROBS_SKIP_IDX
//     and (cell << 1) + LOG_PROBS_EMIT_IDX.
//   H: number of consecutive target rows sharing one source row.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  Indexer3D indexer(maxT, maxU);

  int idx = indexer(bTgt, t, u);

  // Start of this cell's D logits. idx * D can exceed INT_MAX when
  // B * maxT * maxU * D is large, so the offset is computed in 64 bits.
  const int64_t logitsOffset = static_cast<int64_t>(idx) * D;

  // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
  logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
      CAST_DTYPE(logits[logitsOffset + blank]) - denominators[idx];

  if (u < U - 1) {
    // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, u).
    int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
    logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
        CAST_DTYPE(logits[logitsOffset + target]) - denominators[idx];
  }
}

// Forward (alpha) recursion over the RNN-T lattice, evaluated as a wavefront
// of warp-sized chunks:
//   alpha(0, 0) = 0
//   alpha(0, u) = alpha(0, u - 1) + emit(0, u - 1)
//   alpha(t, 0) = alpha(t - 1, 0) + skip(t - 1, 0)
//   alpha(t, u) = lse(alpha(t - 1, u) + skip(t - 1, u),
//                     alpha(t, u - 1) + emit(t, u - 1))
// skip/emit are the logProbs entries produced by ComputeLogProbs; lse is
// math::lse (log-sum-exp, per its name).
//
// Layout: blockIdx.z -> bTgt; a block covers label position
// u = blockIdx.y + 1 and time steps t = blockIdx.x * blockDim.x + threadIdx.x + 1.
// Blocks with blockIdx.y == 0 also fill the u = 0 column.
//
// Inter-block ordering uses one int counter per (bTgt, blockIdx.y) in
// `alpha_counters` (Indexer2D(maxU)). A block spins until
//   * its own counter >= blockIdx.x (earlier t-chunks of this row are done), and
//   * the previous row's counter > blockIdx.x (this t-chunk of row u - 1 is done);
// lane 0 then publishes completion with __threadfence() + atomicAdd(counter, 1).
// The counters presumably must be zero before launch -- confirm with caller.
//
// Shuffles use threadIdx.x as the lane id with a full 0xffffffff mask, so this
// appears to assume blockDim.x == warpSize (TODO confirm at the launch site).
//
// NOTE(review): if T == 1 or U == 1, every thread returns at the bounds check,
// so nothing (not even alpha(0, 0)) is written here -- confirm callers handle
// those lengths.
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
  const int u = blockIdx.y + 1;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == 1 && u == 1) {
    alphas[idxr(bTgt, 0, 0)] = 0;
  }

  if (blockIdx.x > 0) {
    // Wait until the previous warp (along the t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // Wait until the previous warp (along the u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == 1 && u < U) {
    // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
    alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
        logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t < T) {
    // u = 0 column: inclusive prefix sum of skip(t - 1, 0) across the warp,
    // added to alpha at the cell just before this chunk's first t.
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
    alphas[idxr(bTgt, t, 0)] = skip_prob + val;
  }

  if (t < T && u < U) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];

    // Only lane 0's `skip` uses its true predecessor (alpha at the cell before
    // the chunk); the other lanes get theirs through the shuffle loop below.
    CAST_DTYPE skip =
        alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
    CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    // Sequential wave across lanes: at iteration i, lane i receives lane
    // i - 1's finished alpha and finalizes its own value.
    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    alphas[idxr(bTgt, t, u)] = out;
  }

  if (threadIdx.x == 0) {
    // Make this block's alpha writes visible before signalling dependents.
    __threadfence();
    atomicAdd(counter, 1);
  }
}

// Backward (beta) recursion over the RNN-T lattice plus the per-row cost,
// evaluated as a wavefront of warp-sized chunks from the (T-1, U-1) corner:
//   beta(T-1, U-1) = skip(T-1, U-1)
//   beta(T-1, u)   = beta(T-1, u + 1) + emit(T-1, u)
//   beta(t, U-1)   = beta(t + 1, U-1) + skip(t, U-1)
//   beta(t, u)     = lse(beta(t + 1, u) + skip(t, u),
//                        beta(t, u + 1) + emit(t, u))
//   costs[bTgt]    = -beta(0, 0)
//
// Layout: blockIdx.z -> bTgt; a block covers u = U - 2 - blockIdx.y and
// t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x (lane 0 has the largest t).
// Blocks with blockIdx.y == 0 also fill the u = U - 1 column.
//
// Ordering uses one int counter per (bTgt, blockIdx.y) in `betaCounters`,
// with the same spin-wait / __threadfence + atomicAdd protocol as
// ComputeAlphas (counters presumably zeroed before launch -- confirm).
// Shuffles appear to assume blockDim.x == warpSize (TODO confirm).
//
// NOTE(review): if T == 1 or U == 1, every thread returns at the bounds check,
// so neither beta(T-1, U-1) nor costs[bTgt] is written here -- confirm
// callers handle those lengths.
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
  const int u = U - 2 - blockIdx.y;

  if (t < 0 || u < 0) { // out of boundary.
    return;
  }

  int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == T - 2 && u == U - 2) {
    betas[idxr(bTgt, T - 1, U - 1)] =
        logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
  }

  if (blockIdx.x > 0) {
    // Wait until the previous warp (along the t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // Wait until the previous warp (along the u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == T - 2 && u >= 0) {
    // beta(T-1, u) = beta(T-1, u + 1) + logProbs(T-1, u).emit().
    betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
        logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t >= 0) {
    // u = U - 1 column: inclusive prefix sum of skip(t, U - 1) across the
    // warp, added to beta at the cell just after this chunk's largest t.
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    betas[idxr(bTgt, t, U - 1)] =
        betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
  }

  if (t >= 0 && u >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];

    // t + threadIdx.x + 1 is the cell after lane 0's t, so only lane 0's
    // `skip` uses its true successor; the shuffle loop fixes the others.
    CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
    CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    // Sequential wave across lanes: at iteration i, lane i receives lane
    // i - 1's finished beta and finalizes its own value.
    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    betas[idxr(bTgt, t, u)] = out;

    if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
      costs[bTgt] = DTYPE(-out);
    }
  }

  if (threadIdx.x == 0) {
    // Make this block's beta writes visible before signalling dependents.
    __threadfence();
    atomicAdd(counter, 1);
  }
}

// Runs the alpha and beta recursions concurrently in one launch:
// threadIdx.y == 0 executes ComputeAlphas, threadIdx.y == 1 executes
// ComputeBetasCosts (any other threadIdx.y trips the assert).
// The `warpSize` and `numWarps` parameters are not referenced in this body.
// The parameter named `warpSize` shadows the built-in only inside this
// function; the callees still use the built-in.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int warpSize = 0,
    int numWarps = 0,
    int H = 1) {
  assert(threadIdx.y == 0 || threadIdx.y == 1);

  if (threadIdx.y == 0) {
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alpha_counters=*/alpha_counters,
        /*alphas=*/alphas,
        H);
  } else { // threadIdx.y == 1
    ComputeBetasCosts<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*betaCounters=*/betaCounters,
        /*betas=*/betas,
        /*costs=*/costs,
        H);
  }
}

// One thread per lattice cell, using the same (bTgt, t, u) launch layout as
// ComputeLogProbs. All per-cell work -- including any bounds checking -- is
// delegated to ComputeGradientsElement, which is defined outside this file.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  ComputeGradientsElement(
      bTgt,
      t,
      u,
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      clamp,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      denominators,
      alphas,
      betas,
      gradients,
      H);
}

// This is a __global__ wrapper around the ComputeAlphas device function,
// used to enable unit testing.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  ComputeAlphas<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      alpha_counters,
      alphas,
      H);
}

// This is a __global__ wrapper around the ComputeBetasCosts device function,
// used to enable unit testing.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  ComputeBetasCosts<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      betaCounters,
      betas,
      costs,
      H);
}

// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX

} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA