status_t ComputeAlphas()

in torchaudio/csrc/rnnt/gpu/gpu_transducer.h [210:295]


// Host-side driver for the alpha computation described by `workspace`.
//
// Three steps are issued on options.stream_, in this order:
//   1. LogSumExp2D is called over `logits` with the workspace denominator
//      buffer.
//   2. The ComputeLogProbs kernel receives `logits`, `targets`, the length
//      arrays, the denominator buffer and the workspace log-prob buffer.
//   3. The ComputeAlphasWrapper kernel receives the log-prob buffer, the
//      length arrays, the workspace alpha counters and the caller's `alphas`
//      buffer.
//
// Returns:
//   - the status reported by LogSumExp2D when it is not SUCCESS;
//   - COMPUTE_LOG_PROBS_FAILED / COMPUTE_ALPHAS_BETAS_COSTS_FAILED when
//     cudaGetLastError() reports an error right after the matching launch;
//   - SUCCESS otherwise.
//
// This function does not synchronize the stream; the only check after each
// kernel launch is cudaGetLastError().
//
// NOTE(review): presumably `logits` holds (B * H * max_T * max_U) rows of D
// values, as suggested by the N / D arguments handed to LogSumExp2D --
// confirm against the caller.
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  // Snapshot the launch parameters under descriptive names.
  const cudaStream_t stream = options.stream_;
  const int max_src_len = options.maxSrcLen_;
  const int max_tgt_len = options.maxTgtLen_;
  const int num_targets = options.numTargets_;
  const int blank_id = options.blank_;
  const int num_hypos = options.nHypos_;
  // batch size times hypotheses per batch entry; used as the grid z extent.
  const int num_sequences = options.batchSize_ * num_hypos;

  // Step 1: denominators.
  const status_t denominators_status = LogSumExp2D<DTYPE, CAST_DTYPE>(
      stream,
      num_sequences * max_src_len * max_tgt_len,
      num_targets,
      logits,
      workspace.GetPointerToDenominators());
  if (denominators_status != SUCCESS) {
    return denominators_status;
  }

  // Step 2: log probability pairs (blank and target).
  // Grid: x = ceil(max_src_len / MAX_THREADS_PER_BLOCK), y = max_tgt_len,
  // z = num_sequences; each block has MAX_THREADS_PER_BLOCK threads.
  const int time_segments =
      (max_src_len + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
  const dim3 log_probs_grid(time_segments, max_tgt_len, num_sequences);
  const dim3 log_probs_block(MAX_THREADS_PER_BLOCK);

  ComputeLogProbs<DTYPE, CAST_DTYPE>
      <<<log_probs_grid, log_probs_block, 0, stream>>>(
          max_src_len,
          max_tgt_len,
          num_targets,
          blank_id,
          logits,
          targets,
          srcLengths,
          tgtLengths,
          workspace.GetPointerToDenominators(),
          workspace.GetPointerToLogProbs(),
          num_hypos);
  if (cudaGetLastError() != cudaSuccess) {
    return COMPUTE_LOG_PROBS_FAILED;
  }

  // Step 3: alphas.
  // Grid: x = ceil(max_src_len / WARP_SIZE) (the time axis is split into
  // WARP_SIZE-wide chunks), y = max_tgt_len, z = num_sequences.
  const int time_warps = (max_src_len + WARP_SIZE - 1) / WARP_SIZE;
  const dim3 alphas_grid(time_warps, max_tgt_len, num_sequences);
  // Threads are indexed in 2-D; the second extent is 1 because only alphas
  // are computed by this launch.
  const dim3 alphas_block(WARP_SIZE, 1);

  ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
      <<<alphas_grid, alphas_block, 0, stream>>>(
          max_src_len,
          max_tgt_len,
          num_targets,
          blank_id,
          workspace.GetPointerToLogProbs(),
          srcLengths,
          tgtLengths,
          workspace.GetPointerToAlphaCounters(),
          static_cast<volatile DTYPE*>(alphas),
          num_hypos);
  if (cudaGetLastError() != cudaSuccess) {
    return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
  }

  return SUCCESS;
}