in tensorflow_text/core/kernels/constrained_sequence.cc [356:433]
// Greedy (per-step argmax) decoding of one batch entry.
//
// For each step in [0, scores.GetLength(batch)), picks the state whose score,
// adjusted by the transition from the state chosen at the previous step, is
// highest. The choice is written to output_data[step]. Ties go to the
// highest-numbered state because the comparison uses `>=`.
//
// Parameters:
//   scores               - per-(batch, step, state) scores; num_scores() gives
//                          the number of real states.
//   transition_weights   - weight for moving from row state to column state.
//                          An empty matrix means "no transition weights".
//   allowed_transitions  - whether a move from row state to column state is
//                          permitted. An empty matrix means "all permitted".
//   batch                - which batch entry of `scores` to decode.
//   use_log_space        - if true, weights are added to scores; otherwise
//                          scores are multiplied by them.
//   use_start_end_states - if true, index num_scores() acts as an implicit
//                          start state (row, at step 0) and end state (column,
//                          at the last step). NOTE(review): this presumably
//                          means both matrices are
//                          (num_scores()+1) x (num_scores()+1); confirm
//                          against the caller.
//   output_data          - receives one state index (or kErrorState) per step.
//
// If no state is reachable at some step, that step and every later step are
// set to kErrorState.
void GreedyAnalysis(
    const ScoreAccessor &scores,
    const tensorflow::TTypes<const float>::Matrix &transition_weights,
    const tensorflow::TTypes<const bool>::Matrix &allowed_transitions,
    int batch, bool use_log_space, bool use_start_end_states,
    int32 *output_data) {
  const bool weights_present = transition_weights.size() != 0;
  const bool constraints_present = allowed_transitions.size() != 0;
  const int state_count = scores.num_scores();
  // One past the last real state; serves as the implicit start/end state.
  const int boundary_state = state_count;

  // Folds one transition weight into a running score, in log or linear space.
  auto combine = [use_log_space](float score, float weight) -> float {
    return use_log_space ? score + weight : score * weight;
  };

  const int64 step_count = scores.GetLength(batch);
  for (int step = 0; step < step_count; ++step) {
    // The end-state transition only applies on the final step, and only when
    // implicit start/end states are in use.
    const bool is_final =
        step == scores.GetLength(batch) - 1 && use_start_end_states;
    VLOG(2) << "is last step: " << is_final;

    const int prior = step > 0 ? output_data[step - 1] : boundary_state;
    if (prior == kErrorState) {
      // Once decoding falls into the error state, it never leaves it.
      output_data[step] = kErrorState;
      continue;
    }

    // Without an implicit start state, the very first step has no incoming
    // transition, so its raw scores are used unmodified.
    const bool score_transitions = use_start_end_states || step > 0;

    // Stays kErrorState if every candidate is filtered out.
    int winner = kErrorState;
    float winning_score = std::numeric_limits<float>::lowest();
    for (int candidate = 0; candidate < state_count; ++candidate) {
      float candidate_score = scores.GetScore(batch, step, candidate);

      if (score_transitions) {
        if (constraints_present) {
          // The candidate must be enterable from `prior` and, on the final
          // step, must also be allowed to move into the end state.
          const bool permitted =
              allowed_transitions(prior, candidate) &&
              (!is_final || allowed_transitions(candidate, boundary_state));
          if (!permitted) {
            continue;
          }
        }
        if (weights_present) {
          candidate_score =
              combine(candidate_score, transition_weights(prior, candidate));
          if (is_final) {
            // Also account for the cost of leaving into the end state.
            candidate_score = combine(
                candidate_score, transition_weights(candidate, boundary_state));
          }
        }
      }

      if (candidate_score >= winning_score) {
        winner = candidate;
        winning_score = candidate_score;
      }
    }

    output_data[step] = winner;
    VLOG(2) << "Best state for step " << step << " is " << output_data[step]
            << " with score " << winning_score;
  }
}