Status ComputeTopK()

in lingvo/core/ops/beam_search_step_op_kernels.cc [110:346]


// Computes, for every live hypothesis, the top-k candidate extensions for one
// beam-search decoding step, and merges them per beam.
//
// Layout invariants (visible from the indexing below):
//   - `hyps` is a flat vector of num_beams * k entries; the beam of entry
//     `hyp_id` is `hyp_id % num_beams` (see the skip_beam lookup), i.e. the
//     layout is hyp-major: all beams for hyp slot 0, then hyp slot 1, etc.
//   - `scores` is a float matrix of shape [hyps.size(), vocab_size] holding
//     the local (per-step) score of every candidate token for every hyp.
//
// Outputs (all resized to hyps.size()):
//   - eos_in_topk: per-hyp flag set when the hyp terminates this step
//     (via <eos>, or via <eoc> on the last chunk). Declared vector<char>
//     rather than vector<bool> so parallel element writes are safe.
//   - top_k: the surviving (non-terminated) top-k hyps per beam, written in
//     the same hyp-major layout as `hyps`.
//   - eos_hyps: for each terminated hyp slot, the terminated Hyp itself.
//   - terminal_symbols: the terminating symbol (eos_id or eoc_id) per slot.
//
// Note: `step` is not referenced in this function body.
//
// Returns an Internal error on shape mismatches or when a hyp produces no
// top-k entries (which indicates NaN scores from the model).
Status ComputeTopK(int step, const std::vector<Hyp>& hyps, const Tensor& scores,
                   const int32 k, const int32 eos_id, const int32 eoc_id,
                   const int32 num_beams, const float valid_eos_max_logit_delta,
                   const float local_eos_threshold, bool is_first_step,
                   bool is_last_decoder_step, const Tensor& is_last_chunk,
                   bool merge_paths, bool allow_empty_terminated_hyp,
                   bool force_eos_in_top_k, bool force_last_chunk_eoc_in_top_k,
                   int merged_topk_buffer_size_factor,
                   const std::vector<bool>& skip_beam,
                   // Note that this is functionally a bool, however
                   // vector<bool> is not safe to parallel write into
                   // since its underlying storage is at the byte-level.
                   std::vector<char>* eos_in_topk, std::vector<Hyp>* top_k,
                   std::vector<Hyp>* eos_hyps,
                   std::vector<int32>* terminal_symbols) {
  DCHECK(eos_in_topk && top_k && eos_hyps && terminal_symbols);
  // Validate the hyp-major layout: one row of scores per hypothesis, and
  // eos_id must be a valid column index.
  if (hyps.size() != num_beams * k) {
    return tensorflow::errors::Internal(strings::StrCat(
        "Expecting hyps.size()=", num_beams * k, " (num_beams=", num_beams,
        ", k=", k, "), actual hyps.size()=", hyps.size()));
  }
  if (hyps.size() != scores.dim_size(0)) {
    return tensorflow::errors::Internal(strings::StrCat(
        "Expecting scores.shape[0]=", num_beams * k, " (num_beams=", num_beams,
        ", k=", k, "), actual scores.shape[0]=", scores.dim_size(0)));
  }
  if (eos_id >= scores.dim_size(1)) {
    return tensorflow::errors::Internal(strings::StrCat(
        "Expecting eos_id < scores.shape[1]=", scores.dim_size(1),
        ", actual eos_id=", eos_id));
  }

  VLOG(1) << "Topk clear, num_beams: " << num_beams;
  int hyps_size = hyps.size();
  eos_in_topk->clear();
  top_k->clear();
  top_k->resize(hyps_size);
  eos_in_topk->resize(hyps_size);
  eos_hyps->resize(hyps_size);
  terminal_symbols->resize(hyps_size);
  // Process-lifetime thread pool, shared across all calls (intentionally
  // never freed).
  static thread::ThreadPool* workers =
      new thread::ThreadPool(Env::Default(), "topk", kNumWorkers);
  const int num_ids = scores.dim_size(1);
  const auto scores_matrix = scores.matrix<float>();
  // When merging paths, epsilon (<eoc>) tokens are deduped inside the TopK
  // containers; -1 disables that behavior.
  const int epsilon_id_for_path_merging = merge_paths ? eoc_id : -1;
  // One cross-hyp TopK accumulator per beam; worker threads from different
  // hyps of the same beam insert into the same accumulator.
  std::vector<
      TopK<Hyp, HigherScore, ExtractGlobalScore, InsertHypWithEpsilonDedupe>>
      merged_topk_vec(num_beams, TopK<Hyp, HigherScore, ExtractGlobalScore,
                                      InsertHypWithEpsilonDedupe>(
                                     k, epsilon_id_for_path_merging,
                                     merged_topk_buffer_size_factor));
  // Each mutex is used to protect corresponding merged_topk_vec.
  std::vector<mutex> mu_vec(num_beams);
  tensorflow::Status status = Status::OK();
  mutex mu_status;
  // The thread sharding is along the hyps_size.
  Shard(
      kNumWorkers, workers, hyps_size, num_ids,
      [&](int64_t start, int64_t limit) {
        for (int32 hyp_id = start; hyp_id < limit; ++hyp_id) {
          if (is_first_step && hyp_id >= num_beams) {
            // For first step, we only consider the first hyp of each beam, as
            // otherwise we will be continuing k identical hyps along the way.
            continue;
          }
          if (skip_beam[hyp_id % num_beams]) {
            continue;
          }
          // +1 to make sure that at least top-k hypotheses survive even with
          // the special treatment for eos.  +2 if we are also using eoc.
          const int topk_size = k + 1 + static_cast<int>(eoc_id >= 0);
          TopK<Hyp, HigherScoreWithEos, ExtractGlobalScore,
               InsertHypWithEpsilonDedupe>
              topk(topk_size, epsilon_id_for_path_merging, eos_id,
                   is_last_decoder_step);
          // Tracks the k-th best global score seen so far; candidates below
          // this can be skipped without calling topk.Add().
          float bottom_of_topk = -INFINITY;
          int32 id = 0;
          const float current_global_score = hyps[hyp_id].global_score;
      // TODO(xbing): Try AVX512 if it is supported by machine.
#ifdef __AVX__
          // Vectorized pre-filter: compare STRIDE local scores at a time
          // against (bottom_of_topk - current_global_score); only fall into
          // the scalar inner loop when at least one candidate may qualify.
          const int STRIDE =
              sizeof(__m256) /
              sizeof(std::result_of<decltype(scores_matrix)(int, int)>::type);
          // We read STRIDE float values at a single iteration and compare
          // them with this k-th best value. STRIDE - 1 not to read outside
          // the row.
          for (; id + STRIDE - 1 < num_ids; id += STRIDE) {
            if (!all_less_than(&scores_matrix(hyp_id, id),
                               bottom_of_topk - current_global_score)) {
              for (int i = 0; i < STRIDE; ++i) {
                const float score = scores_matrix(hyp_id, id + i);
                const float global_score =
                    current_global_score + score;
                if (global_score >= bottom_of_topk) {
                  bottom_of_topk =
                      topk.Add({hyps[hyp_id].beam_id, hyp_id, id + i, score,
                                global_score, hyps[hyp_id].prev_labels});
                }
              }
            }
          }
      // Non-AVX code below handles the remaining elements.
#endif
          for (; id != num_ids; ++id) {
            const float score = scores_matrix(hyp_id, id);
            const float global_score = current_global_score + score;
            if (global_score >= bottom_of_topk) {
              bottom_of_topk =
                  topk.Add({hyps[hyp_id].beam_id, hyp_id, id, score,
                            global_score, hyps[hyp_id].prev_labels});
            }
          }

          std::vector<Hyp> entries = topk.Get();
          if (entries.empty()) {
            mutex_lock l(mu_status);
            // This happens when global_score is NaN, hence topk.Add() is never
            // called above, as all comparisons against NaNs are false.
            status = tensorflow::errors::Internal(
                "No entries in TopK. This typically happens if the model is "
                "producing NaNs in the output.");
            return;
          }
          // Best-first order; entries.back() is the worst candidate and is
          // the one sacrificed by the forced-eos replacement below.
          std::sort(entries.begin(), entries.end(), HigherScore());
          if (force_eos_in_top_k) {
            // If <eos> did not make the cut, overwrite the worst entry with
            // an <eos> extension so every hyp has a termination candidate.
            if (std::find_if(entries.begin(), entries.end(),
                             [=](const Hyp& hyp) {
                               return hyp.word_id == eos_id;
                             }) == entries.end()) {
              entries.pop_back();
              const float eos_score = scores_matrix(hyp_id, eos_id);
              entries.push_back({hyps[hyp_id].beam_id, hyp_id, eos_id,
                                 eos_score, current_global_score + eos_score,
                                 hyps[hyp_id].prev_labels});
            }
          }
          if (force_last_chunk_eoc_in_top_k && eoc_id >= 0 &&
              is_last_chunk.vec<bool>()(hyp_id) &&
              (std::find_if(entries.begin(), entries.end(),
                            [=](const Hyp& hyp) {return hyp.word_id == eoc_id;})
               == entries.end())) {
            Hyp last_hyp = Hyp(entries.back());
            entries.pop_back();
            entries.pop_back();
            const float eoc_score = scores_matrix(hyp_id, eoc_id);
            // Forced last chunk eoc is located in the second last position.
            // We choose to overwrite the second last position instead of the
            // very last one as the latter may already have been overwritten
            // due to force_eos_in_top_k.
            // Also note when eoc_id >= 0, we have reserved two additional
            // positions with topk_size, one for eos and one for eoc. So we
            // can afford to overwrite a different position for eoc than eos.
            entries.push_back({hyps[hyp_id].beam_id, hyp_id, eoc_id,
                               eoc_score, current_global_score + eoc_score,
                               hyps[hyp_id].prev_labels});
            entries.push_back(last_hyp);
          }
          // A termination is only accepted if its global score is within
          // valid_eos_max_logit_delta of this hyp's best extension.
          const float eos_score_threshold =
              entries[0].global_score - valid_eos_max_logit_delta;
          VLOG(3) << "Best_score=" << entries[0].global_score
                  << " eos_score_threshold=" << eos_score_threshold;
          {
            // Lock the per-beam accumulator: hyps of the same beam may run on
            // different shard threads. The per-hyp output slots written below
            // are disjoint across threads, so only merged_topk_vec needs the
            // lock.
            const int beam_id = hyps[hyp_id].beam_id;
            mutex_lock l(mu_vec[beam_id]);
            for (const auto& e : entries) {
              VLOG(3) << "Extension for beam_id=" << beam_id
                      << ", hyp_id=" << hyp_id
                      << ": global_score=" << e.global_score
                      << ", local_score=" << e.local_score
                      << ", toks=[" << str_util::Join(e.prev_labels, " ")
                      << "], proposing token " << e.word_id;
              if (e.word_id == eos_id) {
                VLOG(3) << "EOS hyp: global_score=" << e.global_score
                        << ", local_score=" << e.local_score
                        << ", toks=[" << str_util::Join(e.prev_labels, " ")
                        << "]";
                // We move terminated hyps off of the beam.
                // On the last decoder step the thresholds are bypassed so
                // every hyp can terminate.
                if (is_last_decoder_step ||
                    (e.global_score > eos_score_threshold &&
                    e.local_score > local_eos_threshold)) {
                  (*eos_in_topk)[hyp_id] = true;
                  (*eos_hyps)[hyp_id] = e;
                  (*terminal_symbols)[hyp_id] = eos_id;
                }
              } else if (eoc_id >= 0 && is_last_chunk.vec<bool>()(hyp_id) &&
                         e.word_id == eoc_id) {
                // At the last chunk and output <epsilon>. We terminate the
                // hypothesis, even though <eos> was not predicted, and
                // indicate that the final symbol for the hypothesis is
                // <epsilon>, not <eos>.
                if (e.global_score > eos_score_threshold &&
                    e.local_score > local_eos_threshold &&
                    // Only allow an empty hyp (all <epsilon>s) to be
                    // considered terminated, if explicitly permitted.
                    // 'prev_labels' contains only non-epsilons.
                    (allow_empty_terminated_hyp || !e.prev_labels.empty())) {
                  VLOG(3) << "Last chunk EOC hyp: global_score="
                          << e.global_score << ", local_score=" << e.local_score
                          << ", toks=[" << str_util::Join(e.prev_labels, " ")
                          << "]";
                  (*eos_in_topk)[hyp_id] = true;
                  (*eos_hyps)[hyp_id] = e;
                  (*terminal_symbols)[hyp_id] = eoc_id;
                }
              } else {
                // Non-terminating extension: contribute to the beam's merged
                // top-k for the next step.
                merged_topk_vec[beam_id].Add(e);
              }
            }
          }
        }
      });
  if (!status.ok()) {
    return status;
  }

  // Sequentially (post-Shard) emit each beam's best surviving hyps into
  // top_k, using the same hyp-major layout as the input: slot j of beam i
  // lands at index j * num_beams + i. Beams with fewer than k survivors
  // leave their remaining slots default-constructed.
  const int hyps_per_beam = k;
  for (int i = 0; i < num_beams; ++i) {
    if (skip_beam[i]) {
      continue;
    }
    auto ith_topk = merged_topk_vec[i].Get();
    std::sort(ith_topk.begin(), ith_topk.end(), HigherScore());
    const int num_hyps =
        std::min(static_cast<int>(ith_topk.size()), hyps_per_beam);
    VLOG(3) << "Active hyps for beam_id=" << i;
    for (int j = 0; j < num_hyps; ++j) {
      (*top_k)[j * num_beams + i] = ith_topk[j];
      VLOG(3) << "Active hyp " << j
              << ", global_score=" << ith_topk[j].global_score
              << ", local score=" << ith_topk[j].local_score
              << ", toks=[" << str_util::Join(ith_topk[j].prev_labels, " ")
              << "]";
    }
  }
  VLOG(1) << "Topk done";
  return Status::OK();
}