in torchaudio/csrc/decoder/src/decoder/Utils.h [65:143]
/**
 * Filter, merge, prune and store a set of candidate states.
 *
 * 1. Every candidate whose `score` is >= `threshold` is collected, by
 *    pointer, into `candidatePtrs`.
 * 2. Candidates for which `compareNoScoreStates` returns 0 (presumably states
 *    that are identical apart from their score — verify against
 *    DecoderState) are merged into a single entry. The kept entry's score
 *    becomes either log(exp(a) + exp(b)) (`logAdd == true`) or max(a, b)
 *    (`logAdd == false`).
 * 3. At most `beamSize` entries with the highest scores are moved into
 *    `outputs`. When `returnSorted` is true they are in descending score
 *    order; otherwise their relative order is unspecified.
 *
 * @param candidates    Candidate states. Selected entries have their score
 *                      updated in place and are then moved from.
 * @param candidatePtrs Scratch buffer of pointers into `candidates`. Pointers
 *                      are appended, not assigned. This function does not
 *                      clear the buffer first, so the caller is expected to
 *                      pass it empty.
 * @param outputs       Cleared, then filled with the surviving states.
 * @param beamSize      Maximum number of states to keep; values <= 0 keep
 *                      none.
 * @param threshold     Minimum score for a candidate to be considered.
 * @param logAdd        Merge duplicates with log-add instead of max.
 * @param returnSorted  Whether `outputs` must be sorted by descending score.
 */
void candidatesStore(
    std::vector<DecoderState>& candidates,
    std::vector<DecoderState*>& candidatePtrs,
    std::vector<DecoderState>& outputs,
    const int beamSize,
    const double threshold,
    const bool logAdd,
    const bool returnSorted) {
  outputs.clear();
  if (candidates.empty()) {
    return;
  }

  /* 1. Select valid candidates */
  for (auto& candidate : candidates) {
    if (candidate.score >= threshold) {
      candidatePtrs.emplace_back(&candidate);
    }
  }
  // Nothing passed the threshold. Without this guard, the merge step below
  // would resize the empty buffer to hold a single null pointer and the
  // output loop would then dereference it.
  if (candidatePtrs.empty()) {
    return;
  }

  /* 2. Merge candidates */
  // Bring identical states next to each other, best score first within
  // each group of equal states.
  std::sort(
      candidatePtrs.begin(),
      candidatePtrs.end(),
      [](const DecoderState* node1, const DecoderState* node2) {
        int cmp = node1->compareNoScoreStates(node2);
        return cmp == 0 ? node1->score > node2->score : cmp > 0;
      });
  // Compact the distinct states into the front of candidatePtrs, in place.
  size_t nHypAfterMerging = 1;
  for (size_t i = 1; i < candidatePtrs.size(); i++) {
    DecoderState* merged = candidatePtrs[nHypAfterMerging - 1];
    if (candidatePtrs[i]->compareNoScoreStates(merged) != 0) {
      // Distinct candidate
      candidatePtrs[nHypAfterMerging] = candidatePtrs[i];
      nHypAfterMerging++;
    } else {
      // Same candidate
      double maxScore = std::max(merged->score, candidatePtrs[i]->score);
      if (logAdd) {
        double minScore = std::min(merged->score, candidatePtrs[i]->score);
        // Numerically stable log(exp(a) + exp(b)). When minScore is infinite,
        // the subtraction below would give inf - inf = NaN for two equal
        // infinities. In that case the exact result is maxScore.
        merged->score = std::isinf(minScore)
            ? maxScore
            : maxScore + std::log1p(std::exp(minScore - maxScore));
      } else {
        merged->score = maxScore;
      }
    }
  }
  candidatePtrs.resize(nHypAfterMerging);

  /* 3. Sort and prune */
  auto compareNodeScore = [](const DecoderState* node1,
                             const DecoderState* node2) {
    return node1->score > node2->score;
  };
  const int nValidHyp = static_cast<int>(candidatePtrs.size());
  // Clamp at zero so a non-positive beam size keeps nothing instead of
  // forming an iterator before begin().
  const int finalSize = std::max(0, std::min(nValidHyp, beamSize));
  if (!returnSorted && nValidHyp > finalSize) {
    // Only the top-finalSize set matters; its internal order does not.
    std::nth_element(
        candidatePtrs.begin(),
        candidatePtrs.begin() + finalSize,
        candidatePtrs.begin() + nValidHyp,
        compareNodeScore);
  } else if (returnSorted) {
    std::partial_sort(
        candidatePtrs.begin(),
        candidatePtrs.begin() + finalSize,
        candidatePtrs.begin() + nValidHyp,
        compareNodeScore);
  }
  outputs.reserve(finalSize);
  for (int i = 0; i < finalSize; i++) {
    outputs.emplace_back(std::move(*candidatePtrs[i]));
  }
}