void Compute()

in tensorflow_text/core/kernels/sentence_breaking_kernels.cc [149:255]
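
Given a batch of pre-tokenized documents in ragged form (per-document
row_lengths plus flat vectors of token words, start/end offsets, and text
properties), this kernel runs SentenceFragmenter over each document and
emits, for every fragment found, its start and end offsets, a properties
bitmask, and the index of its terminal punctuation token. A per-document
fragment count (output_row_lengths) lets callers rebuild the ragged
structure.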


  void Compute(::tensorflow::OpKernelContext* context) override {
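// Fetches input `name`, validates that it is a rank-1 (vector) tensor, and
// binds `name` to a flat vector view of the given dtype.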
#define DECLARE_AND_VALIDATE_INPUT_VECTOR(name, dtype)                        \
  const Tensor* name##_tensor;                                                \
  OP_REQUIRES_OK(context, context->input(#name, &name##_tensor));             \
  OP_REQUIRES(context, TensorShapeUtils::IsVector(name##_tensor->shape()),    \
              InvalidArgument(                                                \
                  absl::StrCat("'", #name, "' must be a vector, got shape: ", \
                               name##_tensor->shape().DebugString())));       \
  const auto& name = name##_tensor->vec<dtype>();

    DECLARE_AND_VALIDATE_INPUT_VECTOR(row_lengths, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_start, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_end, int64);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_word, tstring);
    DECLARE_AND_VALIDATE_INPUT_VECTOR(token_properties, int64);

#undef DECLARE_AND_VALIDATE_INPUT_VECTOR

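    // One converter wrapper is cached per thread; only the target encoding is
    // (re)initialized on each call.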
    static thread_local std::unique_ptr<WrappedConverter> input_encoder;
    if (!input_encoder) {
      input_encoder = absl::make_unique<WrappedConverter>();
    }
    input_encoder->init(input_encoding_);
    OP_REQUIRES(
        context, input_encoder->converter_,
        InvalidArgument("Could not create converter for input encoding: " +
                        input_encoding_));

    UConverter* converter = input_encoder->converter_;
    UnicodeUtil util(converter);

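    // The ragged row lengths must sum to the length of each flat per-token
    // input vector; this is validated below.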
    int num_elements = 0;
    for (int i = 0; i < row_lengths.size(); ++i) {
      num_elements += row_lengths(i);
    }
    OP_REQUIRES(context,
                num_elements == token_start.size() &&
                    token_start.size() == token_end.size() &&
                    token_end.size() == token_word.size(),
                InvalidArgument(absl::StrCat(
                    "num_elements(", num_elements, "), token_start(",
                    token_start.size(), "), token_end(", token_end.size(),
                    "), token_word(", token_word.size(),
                    ") must all be the same size.")));

    // Walk the flat token stream: rebuild each document's tokens, then break
    // the document into sentence fragments.
    int token_index = 0;
    int num_fragments = 0;
    std::vector<std::vector<SentenceFragment>> fragments;
    for (int i = 0; i < row_lengths.size(); ++i) {
      std::vector<Token> tokens;
      Document doc(&tokens);
      for (int j = 0; j < row_lengths(i); ++j) {
        doc.AddToken(
            token_word(token_index), token_start(token_index),
            token_end(token_index), Token::SPACE_BREAK,
            static_cast<Token::TextProperty>(token_properties(token_index)));
        ++token_index;
      }

      // Find fragments.
      SentenceFragmenter fragmenter(&doc, &util);
      std::vector<SentenceFragment> frags;
      OP_REQUIRES_OK(context, fragmenter.FindFragments(&frags));

      num_fragments += frags.size();
      fragments.push_back(std::move(frags));
    }

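    // All fragments are written into flat output vectors; the per-document
    // fragment counts preserve the batch structure.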
    std::vector<int64> fragment_shape;
    fragment_shape.push_back(num_fragments);

    std::vector<int64> doc_batch_shape;
    doc_batch_shape.push_back(fragments.size());

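// Allocates output `name` with the given shape and binds `name` to a mutable
// int64 vector view.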
#define DECLARE_OUTPUT_TENSOR(name, out_shape)                                 \
  Tensor* name##_tensor = nullptr;                                             \
  OP_REQUIRES_OK(context, context->allocate_output(                            \
                              #name, TensorShape(out_shape), &name##_tensor)); \
  auto name = name##_tensor->vec<int64>();

    DECLARE_OUTPUT_TENSOR(fragment_start, fragment_shape);
    DECLARE_OUTPUT_TENSOR(fragment_end, fragment_shape);
    DECLARE_OUTPUT_TENSOR(fragment_properties, fragment_shape);
    DECLARE_OUTPUT_TENSOR(terminal_punc_token, fragment_shape);
    DECLARE_OUTPUT_TENSOR(output_row_lengths, doc_batch_shape);

#undef DECLARE_OUTPUT_TENSOR

    // The fragment outputs are flat vectors of shape [total number of
    // fragments over the entire batch]; output_row_lengths(i) records how
    // many of those fragments belong to document i.
    int element_index = 0;
    // Iterate through all the documents
    for (int i = 0; i < fragments.size(); ++i) {
      const std::vector<SentenceFragment>& fragments_in_doc = fragments[i];
      // Iterate through all the fragments of a document
      for (int j = 0; j < fragments_in_doc.size(); ++j) {
        const SentenceFragment& fragment = fragments_in_doc[j];
        fragment_start(element_index) = fragment.start;
        fragment_end(element_index) = fragment.limit;
        fragment_properties(element_index) = fragment.properties;
        terminal_punc_token(element_index) = fragment.terminal_punc_token;
        ++element_index;
      }
      output_row_lengths(i) = fragments_in_doc.size();
    }
  }
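
For reference, the fragmentation core can be driven without the kernel
plumbing. The sketch below mirrors the per-document loop in Compute(); the
sentence_fragmenter.h header path, the "UTF-8" encoding name, and the
zero-valued TextProperty default are assumptions not taken from this file,
while the WrappedConverter, UnicodeUtil, Document, and SentenceFragmenter
calls follow the signatures used above.

// Minimal sketch: fragment one pre-tokenized document using the same classes
// the kernel drives above. Header path and encoding name are assumptions.
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow_text/core/kernels/sentence_fragmenter.h"  // assumed path

std::vector<SentenceFragment> FragmentOneDoc(
    const std::vector<std::string>& words,
    const std::vector<int64_t>& starts, const std::vector<int64_t>& ends) {
  // Per-thread converter cache, as in Compute(); "UTF-8" is an assumed
  // encoding name.
  static thread_local auto encoder = absl::make_unique<WrappedConverter>();
  encoder->init("UTF-8");
  UnicodeUtil util(encoder->converter_);

  // Rebuild the document from the parallel token vectors.
  std::vector<Token> tokens;
  Document doc(&tokens);
  for (size_t i = 0; i < words.size(); ++i) {
    // A zero-valued TextProperty is an assumed default.
    doc.AddToken(words[i], starts[i], ends[i], Token::SPACE_BREAK,
                 static_cast<Token::TextProperty>(0));
  }

  std::vector<SentenceFragment> frags;
  SentenceFragmenter fragmenter(&doc, &util);
  fragmenter.FindFragments(&frags).IgnoreError();  // error handling elided
  return frags;
}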