status_t ComputeBetas()

in torchaudio/csrc/rnnt/cpu/cpu_transducer.h [133:180]


// Workspace-driven entry point for the beta computation on CPU.
//
// Runs three stages in order, using scratch buffers owned by `workspace`:
//   1. LogSumExp2D over `logits` fills the workspace denominators buffer.
//   2. ComputeLogProbs combines `logits`, `targets` and the denominators
//      into the workspace log-prob buffer.
//   3. The options-based ComputeBetas overload consumes those log-probs and
//      is handed `costs` and `betas` as its output arguments.
//
// Parameters:
//   workspace   supplies the Options and the denominators / log-prob buffers.
//   logits      input scores; treated here as (B * maxT * maxU) rows of
//               D entries each (that is how it is passed to LogSumExp2D).
//   targets     forwarded to ComputeLogProbs.
//   srcLengths  forwarded to ComputeLogProbs and the inner ComputeBetas.
//   tgtLengths  forwarded to ComputeLogProbs and the inner ComputeBetas.
//   costs       forwarded to the inner ComputeBetas (presumably per-sequence
//               output -- confirm against that overload).
//   betas       forwarded to the inner ComputeBetas (presumably the backward
//               lattice output -- confirm against that overload).
//
// Returns SUCCESS once all three stages have been invoked. The stages'
// own results are not inspected here.
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& opts = workspace.GetOptions();

  // This path only drives the CPU kernels; any other device is rejected.
  CHECK_EQ(opts.device_, CPU);

  const int batchSize = opts.batchSize_;
  const int maxSrcLen = opts.maxSrcLen_;
  const int maxTgtLen = opts.maxTgtLen_;
  const int numTargets = opts.numTargets_;

  // One row per (batch, source step, target step) triple.
  const int numRows = batchSize * maxSrcLen * maxTgtLen;

  // Scratch buffers shared between the stages below.
  auto* const denominators = workspace.GetPointerToDenominators();
  auto* const logProbs = workspace.GetPointerToLogProbs();

  // Stage 1: denominators.
  LogSumExp2D<DTYPE, CAST_DTYPE>(numRows, numTargets, logits, denominators);

  // Stage 2: log prob pairs.
  ComputeLogProbs<DTYPE, CAST_DTYPE>(
      opts, logits, targets, srcLengths, tgtLengths, denominators, logProbs);

  // Stage 3: betas (and costs) from the log probs.
  ComputeBetas<DTYPE, CAST_DTYPE>(
      opts, logProbs, srcLengths, tgtLengths, costs, betas);

  return SUCCESS;
}