void LexiconFreeSeq2SeqDecoder::decodeStep()

in recipes/sota/2019/rescoring/src/LexiconFreeSeq2SeqDecoder.cpp [19:178]


void LexiconFreeSeq2SeqDecoder::decodeStep(
    const float* emissions,
    int T,
    int N) {
  // Extend hyp_ buffer
  if (hyp_.size() < maxOutputLength_ + 2) {
    for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) {
      hyp_.emplace(i, std::vector<LexiconFreeSeq2SeqDecoderState>());
    }
  }

  // Start from here.
  hyp_[0].clear();
  hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr);
  completedCandidates_.clear();

  auto hypComparator = [](const LexiconFreeSeq2SeqDecoderState& state1,
                          const LexiconFreeSeq2SeqDecoderState& state2) {
    return state1.score > state2.score;
  };
  // Decode frame by frame
  int t = 0;
  for (; t < maxOutputLength_; t++) {
    candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);

    // Batch forwarding
    rawY_.clear();
    rawPrevStates_.clear();
    for (const LexiconFreeSeq2SeqDecoderState& prevHyp : hyp_[t]) {
      const AMStatePtr& prevState = prevHyp.amState;
      if (prevHyp.token == eos_) {
        continue;
      }
      rawY_.push_back(prevHyp.token);
      rawPrevStates_.push_back(prevState);
    }
    if (rawY_.size() == 0) {
      // all previous hypothesis are completed, add them to the
      // completedCandidates_ before exit the loop
      for (const LexiconFreeSeq2SeqDecoderState& prevHyp : hyp_[t]) {
        completedCandidates_.push_back(prevHyp);
      }
      break;
    }

    std::vector<std::vector<float>> amScores;
    std::vector<AMStatePtr> outStates;

    std::tie(amScores, outStates) =
        amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t);

    std::vector<size_t> idx(amScores.back().size());

    // Generate new hypothesis
    for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) {
      const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo];
      // Change nothing for completed hypothesis
      if (prevHyp.token == eos_) {
        // add to pool of completed hyps to avoid thresholding them in the
        // future (only for full beam)
        completedCandidates_.push_back(prevHyp);
        continue;
      }

      const AMStatePtr& outState = outStates[validHypo];
      if (!outState) {
        validHypo++;
        continue;
      }

      std::iota(idx.begin(), idx.end(), 0);
      if (amScores[validHypo].size() > opt_.beamSizeToken) {
        std::partial_sort(
            idx.begin(),
            idx.begin() + opt_.beamSizeToken,
            idx.end(),
            [&amScores, &validHypo](const size_t& l, const size_t& r) {
              return amScores[validHypo][l] > amScores[validHypo][r];
            });
      }

      for (int r = 0;
           r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken);
           r++) {
        int n = idx[r];
        double amScore = amScores[validHypo][n];

        if (n == eos_) { /* (1) Try eos */
          auto lmStateScorePair = lm_->finish(prevHyp.lmState);
          auto lmScore = lmStateScorePair.second;

          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore,
              lmStateScorePair.first,
              &prevHyp,
              n,
              nullptr,
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        } else { /* (2) Try normal token */
          auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
          auto lmScore = lmStateScorePair.second;
          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              prevHyp.score + amScore + opt_.lmWeight * lmScore,
              lmStateScorePair.first,
              &prevHyp,
              n,
              outState,
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        }
      }
      validHypo++;
    }
    candidatesStore(
        candidates_,
        candidatePtrs_,
        hyp_[t + 1],
        opt_.beamSize,
        candidatesBestScore_ - opt_.beamThreshold,
        opt_.logAdd,
        true);
    updateLMCache(lm_, hyp_[t + 1]);

    if (completedCandidates_.size() >= opt_.beamSize) {
      std::partial_sort(
          completedCandidates_.begin(),
          completedCandidates_.begin() + opt_.beamSize,
          completedCandidates_.end(),
          hypComparator);
      completedCandidates_.resize(opt_.beamSize);
    }
  } // End of decoding

  std::vector<LexiconFreeSeq2SeqDecoderState> finalCandidates;
  if (completedCandidates_.size() > 0) {
    std::partial_sort(
        completedCandidates_.begin(),
        completedCandidates_.begin() + opt_.beamSize,
        completedCandidates_.end(),
        hypComparator);
    completedCandidates_.resize(opt_.beamSize);
    finalCandidates = completedCandidates_;
  } else {
    while (t > 0 && hyp_[t].empty()) {
      --t;
    }
    finalCandidates = hyp_[t];
  }
  hyp_[maxOutputLength_ + 1].resize(finalCandidates.size());
  for (int i = 0; i < finalCandidates.size(); i++) {
    hyp_[maxOutputLength_ + 1][i] = std::move(finalCandidates[i]);
  }
}