in tensorflow_text/core/kernels/sentence_breaking_kernels.cc [149:255]
  void Compute(::tensorflow::OpKernelContext* context) override {
    // Fetch each input tensor, validate that it is a vector, and bind a
    // typed view of its contents to a local of the same name.
#define DECLARE_AND_VALIDATE_INPUT_VECTOR(name, dtype)                       \
  const Tensor* name##_tensor;                                               \
  OP_REQUIRES_OK(context, context->input(#name, &name##_tensor));            \
  OP_REQUIRES(context, TensorShapeUtils::IsVector(name##_tensor->shape()),   \
              InvalidArgument(                                               \
                  absl::StrCat("'", #name, "' must be a vector, got shape: ",\
                               name##_tensor->shape().DebugString())));      \
  const auto& name = name##_tensor->vec<dtype>();

    DECLARE_AND_VALIDATE_INPUT_VECTOR(row_lengths, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_start, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_end, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_word, tstring);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_properties, int64);
#undef DECLARE_AND_VALIDATE_INPUT_VECTOR
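    // For reference, a single invocation above expands (roughly; this is an
    // illustrative sketch, not generated output) to:
    //
    //   const Tensor* row_lengths_tensor;
    //   OP_REQUIRES_OK(context,
    //                  context->input("row_lengths", &row_lengths_tensor));
    //   OP_REQUIRES(context,
    //               TensorShapeUtils::IsVector(row_lengths_tensor->shape()),
    //               InvalidArgument(...));
    //   const auto& row_lengths = row_lengths_tensor->vec<int64>();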
    // The converter is cached per thread so it can be re-used across calls
    // instead of being re-created for every batch.
    static thread_local std::unique_ptr<WrappedConverter> input_encoder;
    if (!input_encoder) {
      input_encoder = absl::make_unique<WrappedConverter>();
    }
    input_encoder->init(input_encoding_);
    OP_REQUIRES(
        context, input_encoder->converter_,
        InvalidArgument("Could not create converter for input encoding: " +
                        input_encoding_));
    UConverter* converter = input_encoder->converter_;
    UnicodeUtil util(converter);
    // The token inputs are flat; row_lengths gives the token count of each
    // document, so the row lengths must sum to the flat token-tensor sizes.
    // token_properties is indexed in lockstep with the other token inputs
    // below, so it must match their length as well.
    int num_elements = 0;
    for (int i = 0; i < row_lengths.size(); ++i) {
      num_elements += row_lengths(i);
    }
    OP_REQUIRES(context,
                num_elements == token_start.size() &&
                    token_start.size() == token_end.size() &&
                    token_end.size() == token_word.size() &&
                    token_word.size() == token_properties.size(),
                InvalidArgument(absl::StrCat(
                    "num_elements(", num_elements, "), token_start(",
                    token_start.size(), "), token_end(", token_end.size(),
                    "), token_word(", token_word.size(),
                    "), token_properties(", token_properties.size(),
                    ") must all be the same size.")));
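    // Illustrative example: for a batch of two documents with 2 and 3 tokens
    // respectively, row_lengths == [2, 3] and each of token_start, token_end,
    // token_word, and token_properties must hold 2 + 3 == 5 flat entries.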
    // Re-assemble each document from the flat token inputs, then run the
    // sentence fragmenter over it.
    int token_index = 0;
    int num_fragments = 0;
    std::vector<std::vector<SentenceFragment>> fragments;
    for (int i = 0; i < row_lengths.size(); ++i) {
      std::vector<Token> tokens;
      Document doc(&tokens);
      for (int j = 0; j < row_lengths(i); ++j) {
        doc.AddToken(
            token_word(token_index), token_start(token_index),
            token_end(token_index), Token::SPACE_BREAK,
            static_cast<Token::TextProperty>(token_properties(token_index)));
        ++token_index;
      }
      // Find the sentence fragments within this document.
      SentenceFragmenter fragmenter(&doc, &util);
      std::vector<SentenceFragment> frags;
      OP_REQUIRES_OK(context, fragmenter.FindFragments(&frags));
      num_fragments += frags.size();
      fragments.push_back(std::move(frags));
    }
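    // Illustrative example: if document 0 yields two fragments and document 1
    // yields one, then fragments == {{f00, f01}, {f10}} and num_fragments ==
    // 3. The outputs below flatten this structure back out.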
    // Allocate the flat, fragment-major outputs plus one row length per
    // input document.
    std::vector<int64> fragment_shape;
    fragment_shape.push_back(num_fragments);
    std::vector<int64> doc_batch_shape;
    doc_batch_shape.push_back(fragments.size());

#define DECLARE_OUTPUT_TENSOR(name, out_shape)                               \
  Tensor* name##_tensor = nullptr;                                           \
  OP_REQUIRES_OK(context,                                                    \
                 context->allocate_output(#name, TensorShape(out_shape),     \
                                          &name##_tensor));                  \
  auto name = name##_tensor->vec<int64>();

    DECLARE_OUTPUT_TENSOR(fragment_start, fragment_shape);
    DECLARE_OUTPUT_TENSOR(fragment_end, fragment_shape);
    DECLARE_OUTPUT_TENSOR(fragment_properties, fragment_shape);
    DECLARE_OUTPUT_TENSOR(terminal_punc_token, fragment_shape);
    DECLARE_OUTPUT_TENSOR(output_row_lengths, doc_batch_shape);
#undef DECLARE_OUTPUT_TENSOR
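    // As with the input macro, each invocation expands (roughly; an
    // illustrative sketch, not generated output) to:
    //
    //   Tensor* fragment_start_tensor = nullptr;
    //   OP_REQUIRES_OK(context, context->allocate_output(
    //                               "fragment_start",
    //                               TensorShape(fragment_shape),
    //                               &fragment_start_tensor));
    //   auto fragment_start = fragment_start_tensor->vec<int64>();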
    // The fragment outputs are flat, with shape [total number of fragments
    // over the entire batch]; output_row_lengths records how many of those
    // fragments belong to each document.
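    // Continuing the example above (illustrative): with fragments ==
    // {{f00, f01}, {f10}}, the loop below writes
    //   fragment_start     == [f00.start, f01.start, f10.start]
    //   output_row_lengths == [2, 1]
    // and likewise for fragment_end, fragment_properties, and
    // terminal_punc_token.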
    int element_index = 0;
    // Iterate through all the documents.
    for (int i = 0; i < fragments.size(); ++i) {
      const std::vector<SentenceFragment>& fragments_in_doc = fragments[i];
      // Iterate through all the fragments of a document.
      for (int j = 0; j < fragments_in_doc.size(); ++j) {
        const SentenceFragment& fragment = fragments_in_doc[j];
        fragment_start(element_index) = fragment.start;
        fragment_end(element_index) = fragment.limit;
        fragment_properties(element_index) = fragment.properties;
        terminal_punc_token(element_index) = fragment.terminal_punc_token;
        ++element_index;
      }
      output_row_lengths(i) = fragments_in_doc.size();
    }
  }
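  // Note: this op emits row *lengths*; callers that need the row-splits form
  // used by ragged tensors can derive it with a prefix sum. A minimal sketch,
  // assuming int64 element access as above:
  //
  //   std::vector<int64> row_splits = {0};
  //   for (int i = 0; i < output_row_lengths.size(); ++i) {
  //     row_splits.push_back(row_splits.back() + output_row_lengths(i));
  //   }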