torchaudio/csrc/rnnt/gpu/gpu_transducer.h (306 lines of code) (raw):
#pragma once
#ifdef USE_CUDA
#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>
namespace torchaudio {
namespace rnnt {
namespace gpu {
#define gpuErrchk(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(
cudaError_t code,
const char* file,
int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(
stderr,
"\nGPUassert: %s %s %d\n",
cudaGetErrorString(code),
file,
line);
if (abort)
exit(code);
}
}
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
cudaStream_t stream,
int N,
int D,
const DTYPE* logits, // [N, D]
CAST_DTYPE* outputs) {
{ // compute max among D.
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
// BUGBUG: These error codes are only accurate when launching with
// blocking. Otherwise they usually reflect earlier errors.
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
}
}
{ // compute log(sum(exp(d_i - max)))
dim3 block_dims(N);
dim3 thread_dims(REDUCE_THREADS);
ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*dim=*/D,
/*inputs=*/logits,
/*outputs=*/outputs);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
}
}
return SUCCESS;
}
// Inputs:
// workspace: workspace.
// logits: pointer to (B, max_T, max_U, D) logits.
// targets: pointer to (B, max_U - 1) targets in the batch.
// srcLengths: pointer to (B, ) source lengths in the batch.
// tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
// costs: pointer to (B, ) costs in the batch.
// gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* gradients = nullptr) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas, betas and costs.
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B * H blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 2. 1 for alpha, 1 for beta
dim3 thread_dims(WARP_SIZE, 2);
ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*beta_counters=*/workspace.GetPointerToBetaCounters(),
/*betas=*/workspace.GetPointerToBetas(),
/*costs=*/costs,
/*warp_size=*/WARP_SIZE,
/*num_warps=*/num_warps,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
if (gradients != nullptr) { // compute gradients.
// don't set gradients to zero to here as gradients might reuse memory from
// logits
int num_blocks =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_blocks, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*clamp=*/clamp,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* alphas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute alphas
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 1 for alpha only
dim3 thread_dims(WARP_SIZE, 1);
ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
/*alphas=*/(volatile DTYPE*)alphas,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
const Workspace<CAST_DTYPE>& workspace,
const DTYPE* logits,
const int* targets,
const int* srcLengths,
const int* tgtLengths,
DTYPE* costs,
DTYPE* betas) {
const Options& options = workspace.GetOptions();
const cudaStream_t& stream = options.stream_;
const int& B = options.batchSize_;
const int& H = options.nHypos_;
const int& max_T = options.maxSrcLen_;
const int& max_U = options.maxTgtLen_;
const int& D = options.numTargets_;
const int& blank = options.blank_;
{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
/*N=*/B * H * max_T * max_U,
/*D=*/D,
/*logits=*/logits,
/*denominators=*/workspace.GetPointerToDenominators());
if (status != SUCCESS) {
return status;
}
}
{ // compute log probability pairs (blank and target).
int num_segments =
(max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
dim3 block_dims(num_segments, max_U, B * H);
dim3 thread_dims(MAX_THREADS_PER_BLOCK);
ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*logits=*/logits,
/*targets=*/targets,
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
}
}
{ // compute betas
// warp is usually a group of threads (32)
int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
// each block is identified by 3 d tuple.
// we are using num_warp * max_U * B blocks
// where num_warp is division among Time axis
dim3 block_dims(num_warps, max_U, B * H);
// each thread is identified by a 2 d tuple
// 2nd dim is 1 for betas only
dim3 thread_dims(WARP_SIZE, 1);
ComputeBetasWrapper<DTYPE, CAST_DTYPE>
<<<block_dims, thread_dims, 0, stream>>>(
/*max_src_len=*/max_T,
/*max_tgt_len=*/max_U,
/*num_targets=*/D,
/*blank=*/blank,
/*log_probs=*/workspace.GetPointerToLogProbs(),
/*srcLengths=*/srcLengths,
/*tgtLengths=*/tgtLengths,
/*alpha_counters=*/workspace.GetPointerToBetaCounters(),
/*alphas=*/(volatile DTYPE*)betas,
costs,
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
}
}
return SUCCESS;
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
#endif // USE_CUDA