void GreedyAnalysis()

in tensorflow_text/core/kernels/constrained_sequence.cc [356:433]


// Computes a greedy (per-step argmax) state sequence for one batch element.
//
// At each step, every state's score is adjusted by the transition from the
// state chosen at the previous step (and, on the last step when start/end
// states are in use, by the transition into the implicit end state). The
// highest adjusted score wins. Transitions disallowed by
// `allowed_transitions` are skipped. If no state can be chosen at a step,
// kErrorState is written for that step and for every later step.
//
// Args:
//   scores: per-step, per-state scores. GetLength(batch) gives the number of
//     steps for `batch`; num_scores() gives the number of states.
//   transition_weights: optional (empty disables it) matrix indexed
//     (from_state, to_state). Index num_states stands for the implicit start
//     state (as a "from" row) and end state (as a "to" column) when
//     `use_start_end_states` is true.
//   allowed_transitions: optional (empty disables it) boolean mask with the
//     same indexing as `transition_weights`.
//   batch: which sequence in `scores` to analyze.
//   use_log_space: if true, transition weights are added to scores;
//     otherwise they are multiplied.
//   use_start_end_states: if true, apply transitions out of the implicit start
//     state on the first step and into the implicit end state on the last one.
//     If false, first-step scores are used unchanged.
//   output_data: receives one state index (or kErrorState) per step. It must
//     have room for at least GetLength(batch) entries.
//
// Notes:
//   - Ties go to the highest-indexed state, because the comparison is `>=`.
//   - Scores of -inf or NaN never compare >= lowest(), so those states are
//     never selected.
void GreedyAnalysis(
    const ScoreAccessor &scores,
    const tensorflow::TTypes<const float>::Matrix &transition_weights,
    const tensorflow::TTypes<const bool>::Matrix &allowed_transitions,
    int batch, bool use_log_space, bool use_start_end_states,
    int32 *output_data) {
  const bool has_transition_weights = transition_weights.size() != 0;
  const bool has_allowed_transitions = allowed_transitions.size() != 0;
  const int num_states = scores.num_scores();
  // Sentinel index of the implicit start/end state in the transition matrices.
  const int out_of_bounds_index = num_states;
  // Query the sequence length once. It is both the loop bound and the value
  // used to detect the final step, so there is no need to re-query it.
  const int64 num_steps = scores.GetLength(batch);

  for (int step = 0; step < num_steps; ++step) {
    // Do final step calculations if this is the final step in the sequence
    // and we are calculating based on implicit start and end states.
    const bool do_final_step = (step == num_steps - 1) && use_start_end_states;
    VLOG(2) << "is last step: " << do_final_step;

    // At step 0 the "previous" state is the implicit start state. It is only
    // consulted when use_start_end_states is true (see apply_transitions).
    const int previous_state =
        (step == 0) ? (out_of_bounds_index) : (output_data[step - 1]);

    if (previous_state == kErrorState) {
      // If the previous state is the error state, the current state must
      // also be the error state.
      output_data[step] = kErrorState;
      continue;
    }

    // If we are not using start/end states AND step is 0, the raw scores are
    // used unchanged. This does not depend on `state`, so compute it once.
    const bool apply_transitions = use_start_end_states || step > 0;

    // If no transition is possible, this will stay the error state.
    int best_new_state = kErrorState;
    float best_new_score = std::numeric_limits<float>::lowest();

    for (int state = 0; state < num_states; ++state) {
      float current_score = scores.GetScore(batch, step, state);

      if (apply_transitions) {
        if (has_allowed_transitions) {
          // Skip this state if the transition from the previous state into it
          // is disallowed. On the final step, also skip it if the transition
          // from it into the implicit end state is disallowed.
          if (!allowed_transitions(previous_state, state) ||
              (do_final_step &&
               !allowed_transitions(state, out_of_bounds_index))) {
            continue;
          }
        }

        if (has_transition_weights) {
          if (use_log_space) {
            current_score += transition_weights(previous_state, state);
          } else {
            current_score *= transition_weights(previous_state, state);
          }
          // On the last step, also apply the weight of transitioning from
          // this state to the out-of-bounds (end) state.
          if (do_final_step) {
            if (use_log_space) {
              current_score += transition_weights(state, out_of_bounds_index);
            } else {
              current_score *= transition_weights(state, out_of_bounds_index);
            }
          }
        }
      }
      // `>=` means that among tied scores the later (higher-index) state wins.
      if (current_score >= best_new_score) {
        best_new_state = state;
        best_new_score = current_score;
      }
    }
    output_data[step] = best_new_state;
    VLOG(2) << "Best state for step " << step << " is " << output_data[step]
            << " with score " << best_new_score;
  }
}