void ComputeAlphasBetas()

in torchaudio/csrc/rnnt/cpu/cpu_kernels.h [249:293]


void ComputeAlphasBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas,
    CAST_DTYPE* betas,
    DTYPE* costs) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);
  //#pragma omp parallel for
  for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
    int i = (t >> 1);
    scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
        /*thread=*/t,
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i],
        /*beta=*/seq_betas[i]);
  }
  for (int b = 0; b < B; ++b) {
    costs[b] = -scores[b << 1];
  }
}