in lingvo/core/ops/beam_search_step_op_kernels.cc [110:346]
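// Computes, for every hypothesis in 'hyps', the top-k extensions of each of
// the 'num_beams' beams for this decoding step. Hypotheses are laid out
// hyp-major: entry j * num_beams + i is the j-th hyp of beam i, and the same
// layout is used for 'top_k', 'eos_in_topk', 'eos_hyps' and
// 'terminal_symbols'. Extensions that terminate a hypothesis (eos, or eoc on
// the last chunk) are reported through 'eos_in_topk'/'eos_hyps'/
// 'terminal_symbols' instead of being merged back into the beam.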
Status ComputeTopK(int step, const std::vector<Hyp>& hyps, const Tensor& scores,
const int32 k, const int32 eos_id, const int32 eoc_id,
const int32 num_beams, const float valid_eos_max_logit_delta,
const float local_eos_threshold, bool is_first_step,
bool is_last_decoder_step, const Tensor& is_last_chunk,
bool merge_paths, bool allow_empty_terminated_hyp,
bool force_eos_in_top_k, bool force_last_chunk_eoc_in_top_k,
int merged_topk_buffer_size_factor,
const std::vector<bool>& skip_beam,
// Note that this is functionally a bool; however,
// vector<bool> is not safe to write into in parallel
// because its elements are bit-packed rather than
// individually addressable.
std::vector<char>* eos_in_topk, std::vector<Hyp>* top_k,
std::vector<Hyp>* eos_hyps,
std::vector<int32>* terminal_symbols) {
DCHECK(eos_in_topk && top_k && eos_hyps && terminal_symbols);
if (hyps.size() != num_beams * k) {
return tensorflow::errors::Internal(strings::StrCat(
"Expecting hyps.size()=", num_beams * k, " (num_beams=", num_beams,
", k=", k, "), actual hyps.size()=", hyps.size()));
}
if (hyps.size() != scores.dim_size(0)) {
return tensorflow::errors::Internal(strings::StrCat(
"Expecting scores.shape[0]=", num_beams * k, " (num_beams=", num_beams,
", k=", k, "), actual scores.shape[0]=", scores.dim_size(0)));
}
if (eos_id >= scores.dim_size(1)) {
return tensorflow::errors::Internal(strings::StrCat(
"Expecting eos_id < scores.shape[1]=", scores.dim_size(1),
", actual eos_id=", eos_id));
}
VLOG(1) << "Topk clear, num_beams: " << num_beams;
int hyps_size = hyps.size();
eos_in_topk->clear();
top_k->clear();
top_k->resize(hyps_size);
eos_in_topk->resize(hyps_size);
eos_hyps->resize(hyps_size);
terminal_symbols->resize(hyps_size);
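// The worker pool is a function-local static, so it is created once and
// reused across all calls to ComputeTopK.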
static thread::ThreadPool* workers =
new thread::ThreadPool(Env::Default(), "topk", kNumWorkers);
const int num_ids = scores.dim_size(1);
const auto scores_matrix = scores.matrix<float>();
const int epsilon_id_for_path_merging = merge_paths ? eoc_id : -1;
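// With merge_paths enabled, the TopK containers below use the eoc id to
// deduplicate candidate hyps that are identical up to epsilons (presumably
// merging their scores inside InsertHypWithEpsilonDedupe); passing -1
// disables this deduplication.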
std::vector<
TopK<Hyp, HigherScore, ExtractGlobalScore, InsertHypWithEpsilonDedupe>>
merged_topk_vec(num_beams, TopK<Hyp, HigherScore, ExtractGlobalScore,
InsertHypWithEpsilonDedupe>(
k, epsilon_id_for_path_merging,
merged_topk_buffer_size_factor));
// Each mutex is used to protect corresponding merged_topk_vec.
std::vector<mutex> mu_vec(num_beams);
tensorflow::Status status = Status::OK();
mutex mu_status;
// The thread sharding is along the hyps_size.
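// 'num_ids' is passed as the estimated cost per hyp so the sharder can
// balance the work across the kNumWorkers threads.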
Shard(
kNumWorkers, workers, hyps_size, num_ids,
[&](int64_t start, int64_t limit) {
for (int32 hyp_id = start; hyp_id < limit; ++hyp_id) {
if (is_first_step && hyp_id >= num_beams) {
// On the first step we only consider the first hyp of each beam;
// otherwise we would keep extending k identical hyps all the way through.
continue;
}
if (skip_beam[hyp_id % num_beams]) {
continue;
}
// +1 so that at least k hypotheses survive the special treatment of eos
// below; +2 if eoc is also in use.
const int topk_size = k + 1 + static_cast<int>(eoc_id >= 0);
TopK<Hyp, HigherScoreWithEos, ExtractGlobalScore,
InsertHypWithEpsilonDedupe>
topk(topk_size, epsilon_id_for_path_merging, eos_id,
is_last_decoder_step);
float bottom_of_topk = -INFINITY;
int32 id = 0;
const float current_global_score = hyps[hyp_id].global_score;
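// Scan all num_ids candidate tokens for this hyp. 'bottom_of_topk' tracks
// the score of the current k-th best entry (topk.Add returns the updated
// bottom), so a candidate whose global score falls below it cannot enter
// topk; the AVX path below uses this to reject whole blocks of STRIDE
// scores at once.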
// TODO(xbing): Try AVX512 if it is supported by machine.
#ifdef __AVX__
const int STRIDE =
sizeof(__m256) /
sizeof(std::result_of<decltype(scores_matrix)(int, int)>::type);
// We read STRIDE float values per iteration and compare them against the
// current k-th best value. The "id + STRIDE - 1 < num_ids" bound keeps the
// reads inside the row.
for (; id + STRIDE - 1 < num_ids; id += STRIDE) {
if (!all_less_than(&scores_matrix(hyp_id, id),
bottom_of_topk - current_global_score)) {
for (int i = 0; i < STRIDE; ++i) {
const float score = scores_matrix(hyp_id, id + i);
const float global_score =
current_global_score + score;
if (global_score >= bottom_of_topk) {
bottom_of_topk =
topk.Add({hyps[hyp_id].beam_id, hyp_id, id + i, score,
global_score, hyps[hyp_id].prev_labels});
}
}
}
}
// Non-AVX code below handles the remaining elements.
#endif
for (; id != num_ids; ++id) {
const float score = scores_matrix(hyp_id, id);
const float global_score = current_global_score + score;
if (global_score >= bottom_of_topk) {
bottom_of_topk =
topk.Add({hyps[hyp_id].beam_id, hyp_id, id, score,
global_score, hyps[hyp_id].prev_labels});
}
}
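// Collect this hyp's candidates; they are sorted by global score below.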
std::vector<Hyp> entries = topk.Get();
if (entries.empty()) {
mutex_lock l(mu_status);
// This happens when all global scores are NaN: every comparison against
// NaN is false, so topk.Add() was never called above.
status = tensorflow::errors::Internal(
"No entries in TopK. This typically happens if the model is "
"producing NaNs in the output.");
return;
}
std::sort(entries.begin(), entries.end(), HigherScore());
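// If requested, guarantee that an eos extension is present among the
// candidates by replacing the weakest entry with one; the eos handling
// further below then always gets a chance to consider it.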
if (force_eos_in_top_k) {
if (std::find_if(entries.begin(), entries.end(),
[=](const Hyp& hyp) {
return hyp.word_id == eos_id;
}) == entries.end()) {
entries.pop_back();
const float eos_score = scores_matrix(hyp_id, eos_id);
entries.push_back({hyps[hyp_id].beam_id, hyp_id, eos_id,
eos_score, current_global_score + eos_score,
hyps[hyp_id].prev_labels});
}
}
if (force_last_chunk_eoc_in_top_k && eoc_id >= 0 &&
is_last_chunk.vec<bool>()(hyp_id) &&
(std::find_if(entries.begin(), entries.end(),
[=](const Hyp& hyp) {
return hyp.word_id == eoc_id;
}) == entries.end())) {
Hyp last_hyp = Hyp(entries.back());
entries.pop_back();
entries.pop_back();
const float eoc_score = scores_matrix(hyp_id, eoc_id);
// Forced last chunk eoc is located in the second last position.
// We choose to overwrite the second last position instead of the
// very last one as the latter may already have been overwritten
// due to force_eos_in_top_k.
// Also note when eoc_id >= 0, we have reserved two additional
// positions with topk_size, one for eos and one for eoc. So we
// can afford to overwrite a different position for eoc than eos.
entries.push_back({hyps[hyp_id].beam_id, hyp_id, eoc_id,
eoc_score, current_global_score + eoc_score,
hyps[hyp_id].prev_labels});
entries.push_back(last_hyp);
}
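// An eos (or last-chunk eoc) extension only terminates a hyp if its global
// score is within valid_eos_max_logit_delta of the best extension found for
// this hyp and its local score exceeds local_eos_threshold (on the last
// decoder step eos is accepted unconditionally).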
const float eos_score_threshold =
entries[0].global_score - valid_eos_max_logit_delta;
VLOG(3) << "Best_score=" << entries[0].global_score
<< " eos_score_threshold=" << eos_score_threshold;
{
const int beam_id = hyps[hyp_id].beam_id;
mutex_lock l(mu_vec[beam_id]);
for (const auto& e : entries) {
VLOG(3) << "Extension for beam_id=" << beam_id
<< ", hyp_id=" << hyp_id
<< ": global_score=" << e.global_score
<< ", local_score=" << e.local_score
<< ", toks=[" << str_util::Join(e.prev_labels, " ")
<< "], proposing token " << e.word_id;
if (e.word_id == eos_id) {
VLOG(3) << "EOS hyp: global_score=" << e.global_score
<< ", local_score=" << e.local_score
<< ", toks=[" << str_util::Join(e.prev_labels, " ")
<< "]";
// We move terminated hyps off of the beam.
if (is_last_decoder_step ||
(e.global_score > eos_score_threshold &&
e.local_score > local_eos_threshold)) {
(*eos_in_topk)[hyp_id] = true;
(*eos_hyps)[hyp_id] = e;
(*terminal_symbols)[hyp_id] = eos_id;
}
} else if (eoc_id >= 0 && is_last_chunk.vec<bool>()(hyp_id) &&
e.word_id == eoc_id) {
// At the last chunk and output <epsilon>. We terminate the
// hypothesis, even though <eos> was not predicted, and
// indicate that the final symbol for the hypothesis is
// <epsilon>, not <eos>.
if (e.global_score > eos_score_threshold &&
e.local_score > local_eos_threshold &&
// Only allow an empty hyp (all <epsilon>s) to be
// considered terminated, if explicitly permitted.
// 'prev_labels' contains only non-epsilons.
(allow_empty_terminated_hyp || !e.prev_labels.empty())) {
VLOG(3) << "Last chunk EOC hyp: global_score="
<< e.global_score << ", local_score=" << e.local_score
<< ", toks=[" << str_util::Join(e.prev_labels, " ")
<< "]";
(*eos_in_topk)[hyp_id] = true;
(*eos_hyps)[hyp_id] = e;
(*terminal_symbols)[hyp_id] = eoc_id;
}
} else {
merged_topk_vec[beam_id].Add(e);
}
}
}
}
});
if (!status.ok()) {
return status;
}
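// Merge phase: for each beam, sort the candidates gathered from all of its
// hyps and write the best k back in the same hyp-major layout as the input,
// i.e. slot j * num_beams + i holds the j-th best surviving hyp of beam i.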
const int hyps_per_beam = k;
for (int i = 0; i < num_beams; ++i) {
if (skip_beam[i]) {
continue;
}
auto ith_topk = merged_topk_vec[i].Get();
std::sort(ith_topk.begin(), ith_topk.end(), HigherScore());
const int num_hyps =
std::min(static_cast<int>(ith_topk.size()), hyps_per_beam);
VLOG(3) << "Active hyps for beam_id=" << i;
for (int j = 0; j < num_hyps; ++j) {
(*top_k)[j * num_beams + i] = ith_topk[j];
VLOG(3) << "Active hyp " << j
<< ", global_score=" << ith_topk[j].global_score
<< ", local score=" << ith_topk[j].local_score
<< ", toks=[" << str_util::Join(ith_topk[j].prev_labels, " ")
<< "]";
}
}
VLOG(1) << "Topk done";
return Status::OK();
}
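// Note: the AVX fast path above relies on an all_less_than() helper defined
// earlier in this file (outside this excerpt). Judging from the call site, it
// answers "are all STRIDE scores in this block below the top-k threshold?".
// A minimal illustrative sketch with those assumed semantics (not the
// original implementation; requires <immintrin.h>):
#ifdef __AVX__
inline bool all_less_than_sketch(const float* p, float threshold) {
  // Broadcast the threshold and load 8 contiguous scores.
  const __m256 limit = _mm256_set1_ps(threshold);
  const __m256 vals = _mm256_loadu_ps(p);
  // Lane-wise "vals < limit"; comparisons against NaN are false, so blocks
  // containing NaN scores are never skipped.
  const __m256 lt = _mm256_cmp_ps(vals, limit, _CMP_LT_OS);
  // All 8 mask bits set <=> every score in the block is below the threshold.
  return _mm256_movemask_ps(lt) == 0xFF;
}
#endif  // __AVX__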