in tensorflow_text/core/kernels/tokenizer_from_logits_kernel.cc [128:223]
void Compute(OpKernelContext* ctx) override {
  // Tokenizes each input string using per-character logits: for every string
  // i, TokenizeByLogits consumes row i of `logits` and appends the resulting
  // tokens and their begin/end offsets; the outputs form a ragged tensor
  // described by `row_splits`.
  const Tensor* strings;
  OP_REQUIRES_OK(ctx, ctx->input("strings", &strings));
  const Tensor* logits;
  OP_REQUIRES_OK(ctx, ctx->input("logits", &logits));
  // Validate ranks before touching dim_size(0) / tensor<float, 3>() /
  // scalar<bool>() below: those accessors CHECK-fail (crashing the process)
  // on a shape mismatch, so reject malformed inputs with a recoverable
  // InvalidArgument error instead.
  OP_REQUIRES(ctx, strings->dims() >= 1,
              errors::InvalidArgument("Expecting strings to have rank >= 1, got ",
                                      strings->dims()));
  OP_REQUIRES(ctx, logits->dims() == 3,
              errors::InvalidArgument("Expecting logits to have rank 3, got ",
                                      logits->dims()));
  OP_REQUIRES(ctx, strings->dim_size(0) == logits->dim_size(0),
              errors::InvalidArgument("Expecting logits to have ",
                                      strings->dim_size(0),
                                      " rows, got ",
                                      logits->dim_size(0)));
  const Tensor* force_split_at_break_character;
  OP_REQUIRES_OK(ctx, ctx->input("force_split_at_break_character",
                                 &force_split_at_break_character));
  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(force_split_at_break_character->shape()),
      errors::InvalidArgument(
          "Expecting force_split_at_break_character to be a scalar, got shape ",
          force_split_at_break_character->shape().DebugString()));
  const bool force_split_at_break_character_bool =
      force_split_at_break_character->scalar<bool>()();
  // Flat accumulators for the ragged output; row_splits starts with the
  // mandatory leading 0.
  std::vector<string> tokens;
  std::vector<int> begin_offset;
  std::vector<int> end_offset;
  std::vector<int> output_row_splits(1, 0);
  // Tensor to access values from logits.  Safe now: rank 3 was checked above.
  const TTypes<const float, 3>::Tensor logits_tensor =
      logits->tensor<float, 3>();
  // Iterate through all the values and tokenize them.
  const auto& strings_vec = strings->flat<tstring>();
  // flat() may hold more elements than dim_size(0) if strings has rank > 1;
  // make sure every index i below has a matching logits row.
  OP_REQUIRES(ctx, logits_tensor.dimension(0) >= strings_vec.size(),
              errors::Internal("Bad logits dimension #0: ",
                               logits_tensor.dimension(0), " < ",
                               strings_vec.size()));
  // Dimension #1 of logits will be checked inside TokenizeByLogits.
  OP_REQUIRES(ctx, logits_tensor.dimension(2) == 2,
              errors::Internal("Bad logits dimension #2: ",
                               logits_tensor.dimension(2), " != 2"));
  for (int64_t i = 0; i < strings_vec.size(); ++i) {
    // Tokenize into tokens and record the offset locations.
    int num_tokens = 0;
    OP_REQUIRES_OK(
        ctx, TokenizeByLogits(
                 strings_vec(i),
                 logits_tensor, i,
                 force_split_at_break_character_bool,
                 &tokens, &begin_offset, &end_offset, &num_tokens));
    // Record the row splits: cumulative token count per input string.
    output_row_splits.push_back(num_tokens + output_row_splits.back());
  }
  // Allocate the four outputs.  Token values and both offset outputs share
  // the same flat length; row_splits has one extra leading entry.
  std::vector<int64> output_tokens_shape;
  output_tokens_shape.push_back(tokens.size());
  std::vector<int64> output_row_splits_shape;
  output_row_splits_shape.push_back(output_row_splits.size());
  Tensor* output_values;
  OP_REQUIRES_OK(ctx, ctx->allocate_output("output_values",
                                           TensorShape(output_tokens_shape),
                                           &output_values));
  auto output_values_vec = output_values->vec<tstring>();
  Tensor* output_row_splits_tensor;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_output("row_splits",
                                      TensorShape(output_row_splits_shape),
                                      &output_row_splits_tensor));
  auto output_row_splits_vec = output_row_splits_tensor->vec<int64>();
  Tensor* start_values;
  OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values",
                                           TensorShape(output_tokens_shape),
                                           &start_values));
  auto start_values_vec = start_values->vec<int64>();
  Tensor* limit_values;
  OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values",
                                           TensorShape(output_tokens_shape),
                                           &limit_values));
  auto limit_values_vec = limit_values->vec<int64>();
  // Copy the accumulated results into the output tensors (int offsets widen
  // to int64).  size_t counters avoid signed/unsigned comparison warnings.
  for (size_t i = 0; i < tokens.size(); ++i) {
    output_values_vec(i) = tokens[i];
  }
  for (size_t i = 0; i < output_row_splits.size(); ++i) {
    output_row_splits_vec(i) = output_row_splits[i];
  }
  for (size_t i = 0; i < begin_offset.size(); ++i) {
    start_values_vec(i) = begin_offset[i];
  }
  for (size_t i = 0; i < end_offset.size(); ++i) {
    limit_values_vec(i) = end_offset[i];
  }
}