in tensorflow_text/core/kernels/constrained_sequence.cc [84:354]
void ViterbiAnalysis(
const ScoreAccessor &scores,
const tensorflow::TTypes<const float>::Matrix &transition_weights,
const tensorflow::TTypes<const bool>::Matrix &allowed_transitions,
const int batch, bool use_log_space, bool use_start_end_states,
int32 *output_data) {
VLOG(2) << "Analyzing batch " << batch;
const bool has_transition_weights = transition_weights.size() != 0;
const bool has_allowed_transitions = allowed_transitions.size() != 0;
const int num_states = scores.num_scores();
const int out_of_bounds_index = num_states;
int64 num_steps = scores.GetLength(batch);
// Create two vectors to hold scores. These will be bound to referents later
// so the names here are somewhat irrelevant.
std::vector<double> scores_a(num_states,
std::numeric_limits<float>::lowest());
std::vector<double> scores_b(num_states,
std::numeric_limits<float>::lowest());
// Create a chart of backpointers. Include rows for [start] and [end]
// transitions. By initializing this to kErrorState, we ensure unreachable
// transitions get marked as errors.
std::vector<std::vector<int>> backpointers(
num_steps, std::vector<int>(num_states, kErrorState));
// Set current and previous references for step 0
std::vector<double> *previous_scores = &scores_a;
std::vector<double> *current_scores = &scores_b;
const bool vlog3 = VLOG_IS_ON(3);
for (int curr_state = 0; curr_state < num_states; ++curr_state) {
std::vector<int> ¤t_bps = backpointers[0];
if (use_start_end_states) {
// Initialize the zeroth step BPs to kOutOfBoundsIndex for all states
// where the OOB->state transition is valid, and set scores as needed.
if (has_allowed_transitions &&
!allowed_transitions(out_of_bounds_index, curr_state)) {
if (vlog3) {
LOG(INFO) << "(" << batch << ", 0, [START]->" << curr_state
<< "): disallowed.";
}
continue;
}
// Because the backpointer vectors are initialized to kErrorState, we
// need only to set the valid transition paths to have come from the
// padding state.
current_bps[curr_state] = out_of_bounds_index;
// For valid transitions, get the score (and adjust as appropriate).
const int step = 0;
float current_score = scores.GetScore(batch, step, curr_state);
if (has_transition_weights) {
if (use_log_space) {
current_score += transition_weights(out_of_bounds_index, curr_state);
} else {
current_score *= transition_weights(out_of_bounds_index, curr_state);
}
}
if (vlog3) {
if (has_transition_weights) {
LOG(INFO) << "(" << batch << ", " << step << ", [START]->"
<< curr_state << "): Total score: " << current_score
<< " (raw: " << scores.GetScore(batch, step, curr_state)
<< ", tw: "
<< transition_weights(out_of_bounds_index, curr_state)
<< ")";
} else {
LOG(INFO) << "(" << batch << ", " << step << ", [START]->"
<< curr_state << "): Total score: " << current_score
<< " (raw: " << scores.GetScore(batch, step, curr_state)
<< ")";
}
}
current_scores->at(curr_state) = current_score;
} else {
// If we don't have specific start and end states, all bp's are valid
// and all starting scores are the unadjusted step 0 scores.
current_bps[curr_state] = out_of_bounds_index;
const int step = 0;
current_scores->at(curr_state) = scores.GetScore(batch, step, curr_state);
}
}
// Update the current scores (and normalize if we're not in log space).
if (!use_log_space) {
const double max_score =
*std::max_element(current_scores->begin(), current_scores->end());
if (max_score > 0) {
for (double &score : *current_scores) score /= max_score;
}
}
// Swap current and previous score arrays, as we are advancing a step.
std::vector<double> *tmp = previous_scores;
previous_scores = current_scores;
current_scores = tmp;
// Handle all steps save for the first and last in this loop.
for (int step = 1; step < num_steps; ++step) {
const std::vector<int> &previous_bps = backpointers[step - 1];
std::vector<int> ¤t_bps = backpointers[step];
for (int curr_state = 0; curr_state < num_states; ++curr_state) {
int best_source_state = kErrorState;
float best_score = std::numeric_limits<float>::lowest();
for (int prev_state = 0; prev_state < num_states; ++prev_state) {
// If the previous state was an error state, pass to the next state.
if (previous_bps[prev_state] == kErrorState) {
if (vlog3) {
LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state
<< "->" << curr_state << "): prev state error.";
}
continue;
}
// If this is not a permitted transition, continue.
if (has_allowed_transitions &&
!allowed_transitions(prev_state, curr_state)) {
if (vlog3) {
LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state
<< "->" << curr_state << "): disallowed.";
}
continue;
}
float current_score = scores.GetScore(batch, step, curr_state);
if (use_log_space) {
current_score += previous_scores->at(prev_state);
} else {
current_score *= previous_scores->at(prev_state);
}
if (has_transition_weights) {
if (use_log_space) {
current_score += transition_weights(prev_state, curr_state);
} else {
current_score *= transition_weights(prev_state, curr_state);
}
}
if (vlog3) {
if (has_transition_weights) {
LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state
<< "->" << curr_state
<< "): Total score: " << current_score
<< " (prev: " << previous_scores->at(prev_state)
<< ", raw: " << scores.GetScore(batch, step, curr_state)
<< ", tw: " << transition_weights(prev_state, curr_state)
<< ")";
} else {
LOG(INFO) << "(" << batch << ", " << step << ", " << prev_state
<< "->" << curr_state
<< "): Total score: " << current_score
<< " (prev: " << previous_scores->at(prev_state)
<< ", raw: " << scores.GetScore(batch, step, curr_state)
<< ")";
}
}
if (current_score >= best_score) {
best_source_state = prev_state;
best_score = current_score;
}
}
current_bps[curr_state] = best_source_state;
current_scores->at(curr_state) = best_score;
}
// Normalize if we're not in log space.
if (!use_log_space) {
const double max_score =
*std::max_element(current_scores->begin(), current_scores->end());
if (max_score > 0) {
for (double &score : *current_scores) score /= max_score;
}
}
// After each step, switch the current scores to the previous scores and
// use the previous previous scores as the current scores.
std::vector<double> *tmp = previous_scores;
previous_scores = current_scores;
current_scores = tmp;
}
// Handle the final transition out of the sequence.
int final_state = out_of_bounds_index;
const std::vector<int> &previous_bps = backpointers[num_steps - 1];
int best_source_state = kErrorState;
float final_score = std::numeric_limits<float>::lowest();
for (int prev_state = 0; prev_state < num_states; ++prev_state) {
// If the previous state was an error state, pass to the next state.
if (previous_bps[prev_state] == kErrorState) {
current_scores->at(prev_state) = std::numeric_limits<float>::lowest();
if (vlog3) {
LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state
<< "->[END]): prev state error.";
}
continue;
}
// If this is not a permitted transition, continue.
if (has_allowed_transitions && use_start_end_states &&
!allowed_transitions(prev_state, final_state)) {
current_scores->at(prev_state) = std::numeric_limits<float>::lowest();
if (vlog3) {
LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state
<< "->[END]): disallowed.";
}
continue;
}
// Weight the final transition score by the probability of exiting the
// sequence as well.
float current_score = previous_scores->at(prev_state);
if (use_start_end_states) {
if (has_transition_weights) {
if (use_log_space) {
current_score += transition_weights(prev_state, final_state);
} else {
current_score *= transition_weights(prev_state, final_state);
}
}
if (vlog3) {
if (has_transition_weights) {
LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state
<< "->[END]): Total score: " << current_score
<< " (prev: " << previous_scores->at(prev_state)
<< ", tw: " << transition_weights(prev_state, final_state)
<< ")";
} else {
LOG(INFO) << "(" << batch << ", " << num_steps << ", " << prev_state
<< "->[END]): Total score: " << current_score
<< " (prev: " << previous_scores->at(prev_state) << ")";
}
}
}
current_scores->at(prev_state) = current_score;
if (current_score >= final_score) {
best_source_state = prev_state;
final_score = current_score;
}
}
if (vlog3) {
LOG(INFO) << "Final score: " << final_score;
}
// Calculate the path.
if (best_source_state == kErrorState) {
// If the best source is an error state, the path is unknowable. Report
// error states for the whole sequence.
for (int64 i = 0; i < scores.GetLength(batch); ++i) {
output_data[i] = kErrorState;
}
} else {
// If the best source is a 'real' state, report the state path.
int steps_to_report = scores.GetLength(batch);
int previous_state = best_source_state;
for (int64 i = steps_to_report - 1; i >= 0; --i) {
output_data[i] = previous_state;
previous_state = backpointers[i][previous_state];
}
}
}