status_t ComputeLogProbs()

in torchaudio/csrc/rnnt/cpu/cpu_kernels.h [113:154]


status_t ComputeLogProbs(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs) {
  // Splits flat, batch-major buffers into per-sequence views, then calls
  // ComputeLogProbsOneSequence once per batch element to fill `logProbs`.
  //
  // Buffer layout implied by the indexing below, for batch element b:
  //   logits:       maxT * maxU * D values starting at b * maxT * maxU * D,
  //                 viewed as {maxT, maxU, D}.
  //   targets:      (maxU - 1) ids starting at b * (maxU - 1).
  //   denominators: maxT * maxU values starting at b * maxT * maxU,
  //                 viewed as {maxT, maxU}.
  //   logProbs:     reinterpreted as an array of LogProbs<CAST_DTYPE>;
  //                 maxT * maxU elements starting at element b * maxT * maxU,
  //                 viewed as {maxT, maxU}. The caller must size the buffer
  //                 for B * maxT * maxU LogProbs<CAST_DTYPE> elements.
  //   srcLengths / tgtLengths: one entry per batch element.
  //
  // Returns SUCCESS unconditionally.
  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  // Per-sequence strides are computed in 64-bit arithmetic. The products
  // below (especially maxT * maxU * D times the batch index) easily exceed
  // INT_MAX for realistic batch / vocabulary sizes, and signed int overflow
  // is undefined behavior.
  using Offset = long long;
  const Offset latticeSize = static_cast<Offset>(maxT) * maxU; // (t, u) cells
  const Offset logitsStride = latticeSize * D;
  const Offset targetsStride = static_cast<Offset>(maxU) - 1;

  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;

  // Guard against a negative batch size, which would otherwise turn into a
  // huge unsigned capacity request.
  const int numSeqs = B > 0 ? B : 0;
  seqLogits.reserve(numSeqs);
  seqTargets.reserve(numSeqs);
  seqDenoms.reserve(numSeqs);
  seqlogProbs.reserve(numSeqs);

  LogProbs<CAST_DTYPE>* const logProbsBase =
      reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs);

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(
        TensorView<const DTYPE>({maxT, maxU, D}, logits + b * logitsStride));
    seqTargets.push_back(targets + b * targetsStride);
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * latticeSize));
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU}, logProbsBase + b * latticeSize));
  }

  // Views are built up front so that this loop's iterations are independent
  // of each other (see the commented-out OpenMP pragma).
  // NOTE(review): any value returned by ComputeLogProbsOneSequence is
  // discarded here — confirm that is intended.
  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // use max 2 * B threads.
    ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*logProbs=*/seqlogProbs[b]);
  }

  return SUCCESS;
}