void Compute()

in tensorflow_text/core/kernels/constrained_sequence_kernel.cc [97:205]


  // Decodes one label sequence per batch element, using either Viterbi
  // decoding (use_viterbi_) or greedy decoding, optionally constrained by an
  // allowed-transitions matrix and/or a transition-weights matrix.
  //
  // Inputs (by index):
  //   0: scores; must be rank 2 or rank 3. Read through ScoreAccessor, which
  //      supplies batch_size / num_steps / num_scores and per-element lengths.
  //   1: sequence lengths; must contain exactly one element per batch entry.
  //      Every length must be non-negative and no larger than num_steps.
  //   2: allowed_transitions; either empty or a bool matrix accepted by
  //      ValidateConstraintTensor.
  //   3: transition_weights; either empty or a float matrix accepted by
  //      ValidateConstraintTensor. When use_log_space_ is false, every value
  //      must be non-negative.
  //
  // Outputs:
  //   0: int32 vector holding the decoded labels of all sequences
  //      concatenated; its size is the sum of all sequence lengths.
  //   1: Tsplits vector of batch_size + 1 offsets into output 0. Sequence `b`
  //      occupies [offsets[b], offsets[b + 1]).
  void Compute(OpKernelContext *context) override {
    const auto &score_tensor = context->input(0);
    OP_REQUIRES(context,
                (score_tensor.shape().dims() == 2) ||
                    (score_tensor.shape().dims() == 3),
                InvalidArgument("The score tensor must be of rank 2 or 3."));
    const auto &lengths_tensor = context->input(1);

    ScoreAccessor scores(score_tensor, lengths_tensor);

    // The scores tensor should be [batch, step, scores].
    const int batch_size = scores.batch_size();
    const int num_steps = scores.num_steps();
    const int num_scores = scores.num_scores();

    OP_REQUIRES(context, lengths_tensor.NumElements() == batch_size,
                InvalidArgument(tensorflow::strings::StrCat(
                    "There should be exactly one length for every batch "
                    "element. Found ",
                    lengths_tensor.NumElements(),
                    " length elements for a batch size of ", batch_size)));

    VLOG(2) << "batch: " << batch_size;
    VLOG(2) << "steps: " << num_steps;
    VLOG(2) << "score: " << num_scores;

    // Make sure there's enough data to advance every sequence. Each length is
    // checked individually and in int64. Narrowing to int could silently
    // truncate a huge length and let it slip past the num_steps check.
    // Because every length is bounded by num_steps, the int64 total cannot
    // overflow.
    int64 total_length = 0;
    for (int i = 0; i < batch_size; ++i) {
      const int64 length = scores.GetLength(i);
      // A negative length would shift later output offsets backwards and
      // make the decoders write outside their slot of the output buffer.
      OP_REQUIRES(context, length >= 0,
                  InvalidArgument(tensorflow::strings::StrCat(
                      "Sequence lengths must be non-negative. Found length ",
                      length, " for batch element ", i)));
      OP_REQUIRES(context, length <= num_steps,
                  InvalidArgument("The scores tensor is too short for the "
                                  "longest sequence length."));
      total_length += length;
    }

    // Validate the constraint tensors. An empty tensor means "no constraint".
    const auto &allowed_transitions_tensor = context->input(2);
    const bool has_allowed_transitions =
        allowed_transitions_tensor.NumElements() != 0;
    VLOG(4) << allowed_transitions_tensor.NumElements();
    if (has_allowed_transitions) {
      OP_REQUIRES_OK(context,
                     ValidateConstraintTensor(allowed_transitions_tensor,
                                              num_scores, use_start_end_states_,
                                              "allowed_transitions"));
    }

    const auto &transition_weights_tensor = context->input(3);

    VLOG(4) << transition_weights_tensor.NumElements();
    const bool has_transition_weights =
        transition_weights_tensor.NumElements() != 0;
    if (has_transition_weights) {
      OP_REQUIRES_OK(context, ValidateConstraintTensor(
                                  transition_weights_tensor, num_scores,
                                  use_start_end_states_, "transition_weights"));

      // If we have transition weights in exp-space, all values must be non-
      // negative.
      if (!use_log_space_) {
        // Build the flat view once instead of on every iteration.
        const auto weights_flat = transition_weights_tensor.flat<float>();
        const int64 num_weights = transition_weights_tensor.NumElements();
        for (int64 i = 0; i < num_weights; ++i) {
          OP_REQUIRES(context, weights_flat(i) >= 0,
                      InvalidArgument("The transition weights tensor must not "
                                      "contain negative values."));
        }
      }
    }

    // Placeholders used when a constraint tensor is absent, so both decoders
    // always receive a matrix view.
    const tensorflow::Tensor empty_float(DT_FLOAT, TensorShape({0, 0}));
    const tensorflow::Tensor empty_bool(DT_BOOL, TensorShape({0, 0}));

    const auto &transition_weights =
        has_transition_weights ? transition_weights_tensor.matrix<float>()
                               : empty_float.matrix<float>();

    const auto &allowed_transitions =
        has_allowed_transitions ? allowed_transitions_tensor.matrix<bool>()
                                : empty_bool.matrix<bool>();

    Tensor *output;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({total_length}), &output));
    int32 *output_data = output->flat<int32>().data();

    Tensor *offsets;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({batch_size + 1}), &offsets));
    Tsplits *offset_data = offsets->flat<Tsplits>().data();
    offset_data[0] = 0;

    // Track the write position in int64 rather than reading it back from the
    // Tsplits offsets, so the output index cannot be corrupted by narrowing.
    int64 step_offset = 0;
    for (int batch = 0; batch < batch_size; ++batch) {
      const int64 sequence_length = scores.GetLength(batch);
      int32 *sequence_output = &output_data[step_offset];
      if (use_viterbi_) {
        DoViterbiAnalysis(transition_weights, allowed_transitions, batch,
                          scores, sequence_output);
      } else {
        DoGreedyAnalysis(transition_weights, allowed_transitions, batch, scores,
                         sequence_output);
      }
      step_offset += sequence_length;
      offset_data[batch + 1] = step_offset;
    }
  }