status_t Compute()

in torchaudio/csrc/rnnt/cpu/cpu_transducer.h [21:82]


// Drives the CPU transducer computation as a fixed sequence of stages. Every
// intermediate buffer is obtained from `workspace`; only `costs` and
// (optionally) `gradients` are caller-provided outputs.
//
//   1. LogSumExp2D        : logits                -> workspace denominators
//   2. ComputeLogProbs    : logits + denominators -> workspace log-probs
//   3. ComputeAlphasBetas : log-probs             -> workspace alphas/betas,
//                                                    and `costs`
//   4. ComputeGradients   : runs only when `gradients` is non-null
//
// Parameters:
//   workspace   supplies the Options and the scratch-buffer pointers.
//   logits      input forwarded to stages 1, 2 and 4.
//   targets     forwarded to stages 2 and 4.
//   srcLengths  forwarded to stages 2, 3 and 4.
//   tgtLengths  forwarded to stages 2, 3 and 4.
//   costs       output handed to ComputeAlphasBetas.
//   gradients   optional output; pass nullptr (the default) to skip stage 4.
//
// Returns SUCCESS after the stages have run. CHECK_EQ guards that the
// workspace options name the CPU device.
//
// NOTE(review): stage 1 is given batchSize * maxSrcLen * maxTgtLen rows of
// numTargets entries, so logits is presumably laid out as
// (batch, maxSrcLen, maxTgtLen, numTargets) -- confirm against LogSumExp2D.
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();
  CHECK_EQ(options.device_, CPU);

  // Problem dimensions, read once from the options.
  const int batchSize = options.batchSize_;
  const int maxSrcLen = options.maxSrcLen_;
  const int maxTgtLen = options.maxTgtLen_;
  const int numTargets = options.numTargets_;

  // Number of rows handed to the log-sum-exp stage.
  const int numRows = batchSize * maxSrcLen * maxTgtLen;

  // Stage 1: denominators.
  LogSumExp2D<DTYPE, CAST_DTYPE>(
      numRows, numTargets, logits, workspace.GetPointerToDenominators());

  // Stage 2: log prob pairs.
  ComputeLogProbs<DTYPE, CAST_DTYPE>(
      options,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      workspace.GetPointerToDenominators(),
      workspace.GetPointerToLogProbs());

  // Stage 3: alphas and betas; this stage also receives `costs`.
  ComputeAlphasBetas<DTYPE, CAST_DTYPE>(
      options,
      workspace.GetPointerToLogProbs(),
      srcLengths,
      tgtLengths,
      workspace.GetPointerToAlphas(),
      workspace.GetPointerToBetas(),
      costs);

  // Stage 4 is optional: without a gradient buffer there is nothing left to do.
  if (gradients == nullptr) {
    return SUCCESS;
  }

  ComputeGradients<DTYPE, CAST_DTYPE>(
      options,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      workspace.GetPointerToDenominators(),
      workspace.GetPointerToAlphas(),
      workspace.GetPointerToBetas(),
      gradients);

  return SUCCESS;
}