in tensorflow_text/core/kernels/constrained_sequence_kernel.cc [97:205]
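// Runs constrained sequence decoding over a batch of per-step score
// distributions: each sequence is decoded either with Viterbi or greedily
// (per use_viterbi_), optionally constrained by an allowed-transition matrix
// and a transition-weight matrix supplied as inputs 2 and 3.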
void Compute(OpKernelContext *context) override {
  const auto &score_tensor = context->input(0);
  OP_REQUIRES(context,
              (score_tensor.shape().dims() == 2) ||
                  (score_tensor.shape().dims() == 3),
              InvalidArgument("The score tensor must be of rank 2 or 3."));
  const auto &lengths_tensor = context->input(1);
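  // ScoreAccessor wraps the score and length tensors behind a single
  // indexing interface so the rank-2 and rank-3 cases can be handled
  // uniformly below.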
  ScoreAccessor scores(score_tensor, lengths_tensor);

  // The scores tensor should be [batch, step, scores].
  const int batch_size = scores.batch_size();
  const int num_steps = scores.num_steps();
  const int num_scores = scores.num_scores();

  OP_REQUIRES(context, lengths_tensor.NumElements() == batch_size,
              InvalidArgument(tensorflow::strings::StrCat(
                  "There should be exactly one length for every batch "
                  "element. Found ",
                  lengths_tensor.NumElements(),
                  " length elements for a batch size of ", batch_size)));

  VLOG(2) << "batch: " << batch_size;
  VLOG(2) << "steps: " << num_steps;
  VLOG(2) << "score: " << num_scores;

  // Make sure there's enough data to advance every sequence.
  int max_length = 0;
  int total_length = 0;
  for (int i = 0; i < batch_size; ++i) {
    int64 length = scores.GetLength(i);
    total_length += length;
    if (length > max_length) {
      max_length = length;
    }
  }
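  // total_length sizes the flattened output allocated below; max_length is
  // the longest single sequence and must fit within the provided steps.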
  OP_REQUIRES(
      context, num_steps >= max_length,
      InvalidArgument(
          "The scores tensor is too short for the longest sequence length."));

  // Validate the constraint tensors.
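  // allowed_transitions, when non-empty, presumably marks in entry (i, j)
  // whether a transition from state i to state j is permitted; handling of
  // the implicit start/end states is delegated to ValidateConstraintTensor
  // via use_start_end_states_.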
  const auto &allowed_transitions_tensor = context->input(2);
  bool has_allowed_transitions =
      allowed_transitions_tensor.NumElements() != 0;
  VLOG(4) << allowed_transitions_tensor.NumElements();
  if (has_allowed_transitions) {
    OP_REQUIRES_OK(context,
                   ValidateConstraintTensor(allowed_transitions_tensor,
                                            num_scores, use_start_end_states_,
                                            "allowed_transitions"));
  }
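  // transition_weights, when non-empty, presumably weights the score of each
  // (from, to) state transition; an empty input disables the adjustment.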
  const auto &transition_weights_tensor = context->input(3);
  VLOG(4) << transition_weights_tensor.NumElements();
  bool has_transition_weights = transition_weights_tensor.NumElements() != 0;
  if (has_transition_weights) {
    OP_REQUIRES_OK(context, ValidateConstraintTensor(
                                transition_weights_tensor, num_scores,
                                use_start_end_states_, "transition_weights"));

    // If we have transition weights in exp-space, all values must be
    // non-negative.
    if (!use_log_space_) {
      for (int i = 0; i < transition_weights_tensor.NumElements(); ++i) {
        OP_REQUIRES(context, transition_weights_tensor.flat<float>()(i) >= 0,
                    InvalidArgument("The transition weights tensor must not "
                                    "contain negative values."));
      }
    }
  }
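  // Stand-in empty matrices let the decode calls below pass a matrix
  // unconditionally; the analysis routines can presumably detect the
  // zero-sized case on their own.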
  const tensorflow::Tensor empty_float(DT_FLOAT, TensorShape({0, 0}));
  const tensorflow::Tensor empty_bool(DT_BOOL, TensorShape({0, 0}));
  const auto &transition_weights =
      has_transition_weights ? transition_weights_tensor.matrix<float>()
                             : empty_float.matrix<float>();
  const auto &allowed_transitions =
      has_allowed_transitions ? allowed_transitions_tensor.matrix<bool>()
                              : empty_bool.matrix<bool>();
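  // Output 0: the decoded states for all sequences, flattened into a single
  // [total_length] vector.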
  Tensor *output;
  OP_REQUIRES_OK(context, context->allocate_output(
                              0, TensorShape({total_length}), &output));
  int32 *output_data = output->flat<int32>().data();
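  // Output 1: row splits in the usual ragged-tensor convention, so sequence
  // b occupies output indices [offset_data[b], offset_data[b + 1]).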
  Tensor *offsets;
  OP_REQUIRES_OK(context, context->allocate_output(
                              1, TensorShape({batch_size + 1}), &offsets));
  Tsplits *offset_data = offsets->flat<Tsplits>().data();
  offset_data[0] = 0;
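  // Decode each sequence independently, writing its states into the slice of
  // the flattened output that begins at its row-split offset.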
  for (int batch = 0; batch < batch_size; ++batch) {
    int step_offset = offset_data[batch];
    int64 sequence_length = scores.GetLength(batch);
    offset_data[batch + 1] = step_offset + sequence_length;
    if (use_viterbi_) {
      DoViterbiAnalysis(transition_weights, allowed_transitions, batch,
                        scores, &output_data[step_offset]);
    } else {
      DoGreedyAnalysis(transition_weights, allowed_transitions, batch, scores,
                       &output_data[step_offset]);
    }
  }
}