in torchaudio/csrc/decoder/src/decoder/LexiconDecoder.cpp [32:229]
void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconDecoderState>());
}
}
std::vector<size_t> idx(N);
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
const TrieNode* prevLex = prevHyp.lex;
const int prevIdx = prevHyp.token;
const float lexMaxScore =
prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;
/* (1) Try children */
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
auto iter = prevLex->children.find(n);
if (iter == prevLex->children.end()) {
continue;
}
const TrieNodePtr& lex = iter->second;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
LMStatePtr lmState;
double lmScore = 0.;
if (isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second;
}
// We eat-up a new token
if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
n != prevIdx) {
if (!lex->children.empty()) {
if (!isLmToken_) {
lmState = prevHyp.lmState;
lmScore = lex->maxScore - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmState,
lex.get(),
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
// If we got a true word
for (auto label : lex->labels) {
if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
// This is to avoid an situation that, when there is word with
// single token spelling (e.g. X -> x) in the lexicon and token `x`
// is predicted in several consecutive frames, multiple word `X`
// will be emitted. This violates the property of CTC, where
// there must be an blank token in between to predict 2 identical
// tokens consecutively.
continue;
}
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.wordScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
label,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
// If we got an unknown word
if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.unkScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
unk_,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
/* (2) Try same lexicon node */
if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
prevLex == lexicon_->getRoot()) {
int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
/* (3) CTC only, try blank */
if (opt_.criterionType == CriterionType::CTC) {
int n = blank_;
double amScore = emissions[t * N + n];
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
// finish proposing
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}