void LexiconDecoder::decodeStep()

in torchaudio/csrc/decoder/src/decoder/LexiconDecoder.cpp [32:229]


void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
  int startFrame = nDecodedFrames_ - nPrunedFrames_;
  // Extend hyp_ buffer
  if (hyp_.size() < startFrame + T + 2) {
    for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
      hyp_.emplace(i, std::vector<LexiconDecoderState>());
    }
  }

  std::vector<size_t> idx(N);
  for (int t = 0; t < T; t++) {
    std::iota(idx.begin(), idx.end(), 0);
    if (N > opt_.beamSizeToken) {
      std::partial_sort(
          idx.begin(),
          idx.begin() + opt_.beamSizeToken,
          idx.end(),
          [&t, &N, &emissions](const size_t& l, const size_t& r) {
            return emissions[t * N + l] > emissions[t * N + r];
          });
    }

    candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
    for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
      const TrieNode* prevLex = prevHyp.lex;
      const int prevIdx = prevHyp.token;
      const float lexMaxScore =
          prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;

      /* (1) Try children */
      for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
        int n = idx[r];
        auto iter = prevLex->children.find(n);
        if (iter == prevLex->children.end()) {
          continue;
        }
        const TrieNodePtr& lex = iter->second;
        double amScore = emissions[t * N + n];
        if (nDecodedFrames_ + t > 0 &&
            opt_.criterionType == CriterionType::ASG) {
          amScore += transitions_[n * N + prevIdx];
        }
        double score = prevHyp.score + amScore;
        if (n == sil_) {
          score += opt_.silScore;
        }

        LMStatePtr lmState;
        double lmScore = 0.;

        if (isLmToken_) {
          auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
          lmState = lmStateScorePair.first;
          lmScore = lmStateScorePair.second;
        }

        // We eat-up a new token
        if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
            n != prevIdx) {
          if (!lex->children.empty()) {
            if (!isLmToken_) {
              lmState = prevHyp.lmState;
              lmScore = lex->maxScore - lexMaxScore;
            }
            candidatesAdd(
                candidates_,
                candidatesBestScore_,
                opt_.beamThreshold,
                score + opt_.lmWeight * lmScore,
                lmState,
                lex.get(),
                &prevHyp,
                n,
                -1,
                false, // prevBlank
                prevHyp.amScore + amScore,
                prevHyp.lmScore + lmScore);
          }
        }

        // If we got a true word
        for (auto label : lex->labels) {
          if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
            // This is to avoid an situation that, when there is word with
            // single token spelling (e.g. X -> x) in the lexicon and token `x`
            // is predicted in several consecutive frames, multiple word `X`
            // will be emitted. This violates the property of CTC, where
            // there must be an blank token in between to predict 2 identical
            // tokens consecutively.
            continue;
          }

          if (!isLmToken_) {
            auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
            lmState = lmStateScorePair.first;
            lmScore = lmStateScorePair.second - lexMaxScore;
          }
          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              score + opt_.lmWeight * lmScore + opt_.wordScore,
              lmState,
              lexicon_->getRoot(),
              &prevHyp,
              n,
              label,
              false, // prevBlank
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        }

        // If we got an unknown word
        if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
          if (!isLmToken_) {
            auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
            lmState = lmStateScorePair.first;
            lmScore = lmStateScorePair.second - lexMaxScore;
          }
          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              score + opt_.lmWeight * lmScore + opt_.unkScore,
              lmState,
              lexicon_->getRoot(),
              &prevHyp,
              n,
              unk_,
              false, // prevBlank
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        }
      }

      /* (2) Try same lexicon node */
      if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
          prevLex == lexicon_->getRoot()) {
        int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
        double amScore = emissions[t * N + n];
        if (nDecodedFrames_ + t > 0 &&
            opt_.criterionType == CriterionType::ASG) {
          amScore += transitions_[n * N + prevIdx];
        }
        double score = prevHyp.score + amScore;
        if (n == sil_) {
          score += opt_.silScore;
        }

        candidatesAdd(
            candidates_,
            candidatesBestScore_,
            opt_.beamThreshold,
            score,
            prevHyp.lmState,
            prevLex,
            &prevHyp,
            n,
            -1,
            false, // prevBlank
            prevHyp.amScore + amScore,
            prevHyp.lmScore);
      }

      /* (3) CTC only, try blank */
      if (opt_.criterionType == CriterionType::CTC) {
        int n = blank_;
        double amScore = emissions[t * N + n];
        candidatesAdd(
            candidates_,
            candidatesBestScore_,
            opt_.beamThreshold,
            prevHyp.score + amScore,
            prevHyp.lmState,
            prevLex,
            &prevHyp,
            n,
            -1,
            true, // prevBlank
            prevHyp.amScore + amScore,
            prevHyp.lmScore);
      }
      // finish proposing
    }

    candidatesStore(
        candidates_,
        candidatePtrs_,
        hyp_[startFrame + t + 1],
        opt_.beamSize,
        candidatesBestScore_ - opt_.beamThreshold,
        opt_.logAdd,
        false);
    updateLMCache(lm_, hyp_[startFrame + t + 1]);
  }

  nDecodedFrames_ += T;
}