in torchaudio/csrc/rnnt/gpu/gpu_transducer.h [210:295]
// Enqueues the work that fills `alphas` on the stream stored in the
// workspace options.
//
// Three stages are issued back to back on that stream:
//   1. LogSumExp2D over `logits`, writing the workspace denominators.
//   2. ComputeLogProbs, writing the workspace log-probability pairs
//      (blank and target).
//   3. ComputeAlphasWrapper, writing `alphas`.
//
// Args:
//   workspace:  provides the options (stream, sizes, blank index) and the
//               denominator / log-prob / alpha-counter buffers.
//   logits:     input to stages 1 and 2.
//   targets:    forwarded to ComputeLogProbs.
//   srcLengths: forwarded to ComputeLogProbs and ComputeAlphasWrapper.
//   tgtLengths: forwarded to ComputeLogProbs and ComputeAlphasWrapper.
//   alphas:     output buffer handed to ComputeAlphasWrapper.
//
// Returns:
//   SUCCESS when every stage was issued without a reported error; otherwise
//   the non-SUCCESS status from LogSumExp2D, COMPUTE_LOG_PROBS_FAILED, or
//   COMPUTE_ALPHAS_BETAS_COSTS_FAILED.
//
// NOTE(review): errors are detected with cudaGetLastError() right after each
// launch and this function itself does not synchronize the stream, so faults
// raised while the kernels execute are presumably reported by a later CUDA
// call -- confirm against the callers.
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& opts = workspace.GetOptions();
  const cudaStream_t stream = opts.stream_;
  const int batch_size = opts.batchSize_;
  const int num_hypos = opts.nHypos_;
  const int max_src_len = opts.maxSrcLen_;
  const int max_tgt_len = opts.maxTgtLen_;
  const int num_targets = opts.numTargets_;
  const int blank_id = opts.blank_;

  // One sequence per (batch entry, hypothesis) pair.
  const int num_sequences = batch_size * num_hypos;

  // Stage 1: denominators, one log-sum-exp over num_targets per
  // (sequence, t, u) row.
  const status_t lse_status = LogSumExp2D<DTYPE, CAST_DTYPE>(
      stream,
      num_sequences * max_src_len * max_tgt_len, // N
      num_targets, // D
      logits,
      workspace.GetPointerToDenominators());
  if (lse_status != SUCCESS) {
    return lse_status;
  }

  // Stage 2: log probability pairs (blank and target).
  // The time axis is cut into segments of MAX_THREADS_PER_BLOCK threads
  // (ceil division); the grid is (segments, max_tgt_len, num_sequences).
  const int num_time_segments =
      (max_src_len + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
  dim3 log_probs_grid(num_time_segments, max_tgt_len, num_sequences);
  dim3 log_probs_block(MAX_THREADS_PER_BLOCK);
  ComputeLogProbs<DTYPE, CAST_DTYPE>
      <<<log_probs_grid, log_probs_block, 0, stream>>>(
          max_src_len,
          max_tgt_len,
          num_targets,
          blank_id,
          logits,
          targets,
          srcLengths,
          tgtLengths,
          workspace.GetPointerToDenominators(),
          workspace.GetPointerToLogProbs(),
          num_hypos);
  if (cudaGetLastError() != cudaSuccess) {
    return COMPUTE_LOG_PROBS_FAILED;
  }

  // Stage 3: alphas.
  // The time axis is cut into chunks of WARP_SIZE threads (ceil division);
  // the grid is (chunks, max_tgt_len, num_sequences). Blocks are
  // WARP_SIZE x 1: the second thread dimension is 1 because only alphas
  // are computed here.
  const int num_time_warps = (max_src_len + WARP_SIZE - 1) / WARP_SIZE;
  dim3 alphas_grid(num_time_warps, max_tgt_len, num_sequences);
  dim3 alphas_block(WARP_SIZE, 1);
  ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
      <<<alphas_grid, alphas_block, 0, stream>>>(
          max_src_len,
          max_tgt_len,
          num_targets,
          blank_id,
          workspace.GetPointerToLogProbs(),
          srcLengths,
          tgtLengths,
          workspace.GetPointerToAlphaCounters(),
          (volatile DTYPE*)alphas, // the kernel takes alphas as volatile
          num_hypos);
  if (cudaGetLastError() != cudaSuccess) {
    return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
  }

  return SUCCESS;
}