in torchaudio/csrc/rnnt/cpu/cpu_kernels.h [113:154]
/// Runs ComputeLogProbsOneSequence over every sequence of a batch.
///
/// Each flat input/output buffer is sliced into `options.batchSize_`
/// consecutive per-sequence chunks, and every chunk is wrapped in a view
/// before the per-sequence routine is invoked.
///
/// NOTE(review): DTYPE and CAST_DTYPE are presumably template parameters
/// declared just above this definition — the template header is not visible
/// in this excerpt; confirm.
///
/// @param options      Supplies batchSize_ (B), maxSrcLen_ (maxT),
///                     maxTgtLen_ (maxU) and numTargets_ (D).
/// @param logits       Read as B consecutive chunks of maxT * maxU * D
///                     elements, each viewed with dims {maxT, maxU, D}.
/// @param targets      Read as B consecutive chunks of (maxU - 1) ints.
/// @param srcLengths   Indexed by batch element; forwarded as srcLen.
/// @param tgtLengths   Indexed by batch element; forwarded as tgtLen after
///                     adding 1 (prepended blank, per the original comment).
/// @param denominators Read as B consecutive chunks of maxT * maxU elements,
///                     each viewed with dims {maxT, maxU}.
/// @param logProbs     Output buffer, reinterpreted as an array of
///                     LogProbs<CAST_DTYPE> and sliced into B chunks of
///                     maxT * maxU such elements. The layout of
///                     LogProbs<CAST_DTYPE> is not visible here — verify the
///                     buffer is sized accordingly.
/// @return Always SUCCESS; any result of the per-sequence call is not
///         inspected.
status_t ComputeLogProbs(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs) {
  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  // An empty (or non-positive) batch has no sequences to process. Returning
  // here also keeps a negative count away from vector::reserve below.
  if (B <= 0) {
    return SUCCESS;
  }

  // Element offsets are computed in a 64-bit-or-wider type: products such as
  // b * maxT * maxU * D routinely exceed INT_MAX for realistic tensor sizes,
  // and signed int overflow is undefined behavior.
  using Offset = long long;
  const Offset gridStride = static_cast<Offset>(maxT) * maxU;
  const Offset logitsStride = gridStride * D;
  const Offset targetsStride = static_cast<Offset>(maxU) - 1;

  // The output buffer is addressed in units of LogProbs<CAST_DTYPE>.
  LogProbs<CAST_DTYPE>* const logProbsBase =
      reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs);

  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;

  // The final size is known up front, so allocate once.
  seqLogits.reserve(B);
  seqTargets.reserve(B);
  seqDenoms.reserve(B);
  seqlogProbs.reserve(B);

  // Build the per-sequence views into the flat batch buffers.
  for (int b = 0; b < B; ++b) {
    const Offset gridOffset = b * gridStride;
    seqLogits.push_back(
        TensorView<const DTYPE>({maxT, maxU, D}, logits + b * logitsStride));
    seqTargets.push_back(targets + b * targetsStride);
    seqDenoms.push_back(
        TensorView<const CAST_DTYPE>({maxT, maxU}, denominators + gridOffset));
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU}, logProbsBase + gridOffset));
  }

  // Sequences are processed serially; the OpenMP directive below is disabled
  // in the source (original note: "use max 2 * B threads").
  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) {
    ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*logProbs=*/seqlogProbs[b]);
  }

  return SUCCESS;
}