void candidatesStore()

in torchaudio/csrc/decoder/src/decoder/Utils.h [65:143]


// Merges, prunes and stores the beam-search candidates of one decoding step.
//
//   1. Candidates whose score is below `threshold` are dropped.
//   2. Candidates with identical non-score state (compareNoScoreStates == 0)
//      are merged into a single entry. The survivor's score becomes the
//      maximum of the merged scores, or log(exp(a) + exp(b)) when `logAdd` is
//      true.
//   3. At most `beamSize` of the best-scoring merged candidates are moved into
//      `outputs`. They are sorted by descending score when `returnSorted` is
//      true; otherwise their order is not guaranteed to be score order.
//
// @param candidates    Candidate states. Entries copied to `outputs` are moved
//                      from, and entries merged into another have their score
//                      field updated in place.
// @param candidatePtrs Scratch buffer. Pointers to candidates passing the
//                      threshold are appended, then merged and reordered in
//                      place. NOTE(review): callers appear expected to clear
//                      it beforehand; confirm.
// @param outputs       Cleared on entry, then filled with the kept candidates.
// @param beamSize      Maximum number of candidates to keep; values <= 0 keep
//                      none.
// @param threshold     Minimum score a candidate needs to be considered.
// @param logAdd        Merge duplicate states with log-sum-exp instead of max.
// @param returnSorted  Whether `outputs` must be sorted by descending score.
void candidatesStore(
    std::vector<DecoderState>& candidates,
    std::vector<DecoderState*>& candidatePtrs,
    std::vector<DecoderState>& outputs,
    const int beamSize,
    const double threshold,
    const bool logAdd,
    const bool returnSorted) {
  outputs.clear();
  if (candidates.empty()) {
    return;
  }

  /* 1. Select valid candidates */
  for (auto& candidate : candidates) {
    if (candidate.score >= threshold) {
      candidatePtrs.emplace_back(&candidate);
    }
  }
  // Nothing reached the threshold (e.g. threshold above the best score, or NaN
  // scores). Without this guard the merge step below would resize(1) an empty
  // vector, creating a nullptr entry that is later dereferenced.
  if (candidatePtrs.empty()) {
    return;
  }

  /* 2. Merge candidates */
  // Order by state first, and by descending score within one state. Every
  // group of equal states is then contiguous with its best entry first.
  std::sort(
      candidatePtrs.begin(),
      candidatePtrs.end(),
      [](const DecoderState* node1, const DecoderState* node2) {
        int cmp = node1->compareNoScoreStates(node2);
        return cmp == 0 ? node1->score > node2->score : cmp > 0;
      });

  // Compact distinct states to the front of candidatePtrs. Duplicates are
  // folded into the most recently kept entry of the same state.
  size_t nHypAfterMerging = 1;
  for (size_t i = 1; i < candidatePtrs.size(); ++i) {
    DecoderState* current = candidatePtrs[i];
    DecoderState* kept = candidatePtrs[nHypAfterMerging - 1];
    if (current->compareNoScoreStates(kept) != 0) {
      // Distinct candidate
      candidatePtrs[nHypAfterMerging] = current;
      ++nHypAfterMerging;
    } else {
      // Same candidate
      const double maxScore = std::max(kept->score, current->score);
      if (logAdd) {
        // Numerically stable log(exp(max) + exp(min)).
        const double minScore = std::min(kept->score, current->score);
        kept->score = maxScore + std::log1p(std::exp(minScore - maxScore));
      } else {
        kept->score = maxScore;
      }
    }
  }
  candidatePtrs.resize(nHypAfterMerging);

  /* 3. Sort and prune */
  auto compareNodeScore = [](const DecoderState* node1,
                             const DecoderState* node2) {
    return node1->score > node2->score;
  };

  const int nValidHyp = static_cast<int>(candidatePtrs.size());
  // Clamp at 0 so a negative beamSize cannot form an iterator before begin().
  const int finalSize = std::max(0, std::min(nValidHyp, beamSize));
  if (returnSorted) {
    std::partial_sort(
        candidatePtrs.begin(),
        candidatePtrs.begin() + finalSize,
        candidatePtrs.end(),
        compareNodeScore);
  } else if (nValidHyp > finalSize) {
    // Only the membership of the top `finalSize` entries matters here.
    std::nth_element(
        candidatePtrs.begin(),
        candidatePtrs.begin() + finalSize,
        candidatePtrs.end(),
        compareNodeScore);
  }

  outputs.reserve(finalSize);
  for (int i = 0; i < finalSize; ++i) {
    outputs.emplace_back(std::move(*candidatePtrs[i]));
  }
}