in torchaudio/csrc/rnnt/cpu/cpu_transducer.h [133:180]
// Workspace-level entry point that runs the backward (beta) pass of the
// RNN-T loss on CPU.
//
// The pipeline has three stages:
//   1. A log-sum-exp over each length-D row of `logits`. There are
//      batchSize_ * maxSrcLen_ * maxTgtLen_ rows, and the results go into the
//      workspace's denominator buffer.
//   2. Log-probability computation from `logits`, `targets`, the per-sequence
//      lengths and those denominators. The results go into the workspace's
//      log-prob buffer.
//   3. The lower-level ComputeBetas overload, which consumes the log-prob
//      buffer and writes `costs` and `betas`.
//
// NOTE(review): `logits` is presumably laid out as
// (batchSize_, maxSrcLen_, maxTgtLen_, numTargets_). That is inferred from the
// N/D arguments passed to LogSumExp2D below; confirm against callers.
//
// Params:
//   workspace  - supplies the Options and the scratch buffers (denominators,
//                log probs) written during stages 1 and 2.
//   logits     - raw model outputs.
//   targets    - target label indices, forwarded to ComputeLogProbs.
//   srcLengths - per-sequence source lengths.
//   tgtLengths - per-sequence target lengths.
//   costs      - output, filled by the lower-level ComputeBetas.
//   betas      - output, filled by the lower-level ComputeBetas.
// Returns: SUCCESS. No other status is produced on this path.
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& opts = workspace.GetOptions();

  // This implementation is CPU-only; guard against a mismatched workspace.
  CHECK_EQ(opts.device_, CPU);

  // Stage 1: one log-sum-exp normalizer per (batch, t, u) row of width D.
  const int rowCount = opts.batchSize_ * opts.maxSrcLen_ * opts.maxTgtLen_;
  const int rowWidth = opts.numTargets_;
  LogSumExp2D<DTYPE, CAST_DTYPE>(
      rowCount, rowWidth, logits, workspace.GetPointerToDenominators());

  // Stage 2: log probabilities, using the normalizers from stage 1.
  ComputeLogProbs<DTYPE, CAST_DTYPE>(
      opts,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      workspace.GetPointerToDenominators(),
      workspace.GetPointerToLogProbs());

  // Stage 3: the beta recursion itself, which also yields the per-sequence costs.
  ComputeBetas<DTYPE, CAST_DTYPE>(
      opts,
      workspace.GetPointerToLogProbs(),
      srcLengths,
      tgtLengths,
      costs,
      betas);

  return SUCCESS;
}