absl::Status Invoke()

in tensorflow_text/core/kernels/ngrams_kernel_template.h [124:215]


  absl::Status Invoke(InvokeContext* ctx) {
    using Tsplits = int64_t;
    // Storage for the dummy input and output row_splits used in the tensor
    // case.
    std::vector<Tsplits> tensor_input_row_splits;
    std::vector<Tsplits> tensor_output_row_splits;

    const Tsplits* input_row_splits;
    Tsplits* output_row_splits;
    int n_row_splits = 0;

    SH_ASSIGN_OR_RETURN(const auto input_values, ctx->GetInput(kValues));
    const Shape input_values_shape(input_values->Shape());

    // Dense tensor output: build dummy innermost row_splits so the dense input
    // can flow through the same n-gram loop used for ragged inputs.
    if (ctx->NumOutputs() == 1) {
      // Generate mock input and output innermost row_splits.
      int64_t total_tokens =
          input_values->template Data<tensorflow::tstring>().size();
      int64_t tokens_per_element =
          input_values_shape->at(input_values_shape->size() - 1);
      tensor_output_row_splits.resize(total_tokens / tokens_per_element + 1);
      for (int64_t i = 0; i <= total_tokens; i += tokens_per_element) {
        tensor_input_row_splits.push_back(i);
      }
      input_row_splits = tensor_input_row_splits.data();
      output_row_splits = tensor_output_row_splits.data();
      n_row_splits = tensor_input_row_splits.size();
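      // Illustration (not from the original source): a [2, 3] string input has
      // total_tokens = 6 and tokens_per_element = 3, so the dummy
      // input_row_splits become {0, 3, 6} and n_row_splits = 3, i.e. two
      // uniform rows of three tokens each.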
    } else {
      // RaggedTensor output: one row_splits tensor accompanies the values for
      // each ragged dimension, starting at input index kRowSplitsStart.
      int index = 0;
      const int num_row_splits = ctx->NumInputs() - kRowSplitsStart;
      while (index < num_row_splits - 1) {
        SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits,
                            ctx->GetInput(kRowSplitsStart + index));
        SH_ASSIGN_OR_RETURN(
            const auto output_tensor_row_splits,
            ctx->GetOutput(kRowSplitsStart + index,
                           Shape(input_tensor_row_splits->Shape())));
        const auto input_buffer =
            input_tensor_row_splits->template Data<Tsplits>();
        const auto output_buffer =
            output_tensor_row_splits->template Data<Tsplits>();
        std::memcpy(output_buffer.data(), input_buffer.data(),
                    input_buffer.size() * sizeof(Tsplits));
        ++index;
      }
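      // All but the innermost row_splits are copied through unchanged:
      // generating n-grams only changes how many values each innermost row
      // holds, so the outer ragged structure is preserved.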

      SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits,
                          ctx->GetInput(kRowSplitsStart + index));
      SH_ASSIGN_OR_RETURN(
          const auto output_tensor_row_splits,
          ctx->GetOutput(kRowSplitsStart + index,
                         Shape(input_tensor_row_splits->Shape())));
      input_row_splits =
          input_tensor_row_splits->template Data<Tsplits>().data();
      output_row_splits =
          output_tensor_row_splits->template Data<Tsplits>().data();
      n_row_splits = input_tensor_row_splits->Shape().at(0);
    }

    const auto input_values_data =
        input_values->template Data<tensorflow::tstring>();

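    // Both code paths converge here: input_row_splits/output_row_splits now
    // describe the innermost ragged dimension. Slide a window of width_ tokens
    // across each innermost row and join each window with string_separator_; a
    // row with T tokens produces max(0, T - width_ + 1) n-grams, e.g. with
    // width_ = 2 and separator " " the row ["a", "b", "c"] yields "a b", "b c".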
    std::vector<std::string> buffer;
    for (int i = 0; i < n_row_splits - 1; ++i) {
      output_row_splits[i] = buffer.size();
      std::vector<tensorflow::tstring> tokens;
      for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
        tokens.emplace_back(input_values_data.at(j));
        if (tokens.size() < width_) continue;
        tokens.erase(tokens.begin(), tokens.begin() + tokens.size() - width_);
        buffer.push_back(absl::StrJoin(tokens, string_separator_));
      }
    }
    output_row_splits[n_row_splits - 1] = buffer.size();
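    // buffer now holds every generated n-gram in row order, and
    // output_row_splits records each row's [start, end) offsets into it.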

    // Allocate the output values tensor: dense outputs get a shape derived
    // from the input shape and the n-gram width, while ragged outputs are a
    // flat 1-D tensor holding every generated n-gram.
    tflite::shim::TensorViewOr output_values_or;
    if (ctx->NumOutputs() == 1) {
      output_values_or = ctx->GetOutput(
          kValues, OutputValuesTensorShape(input_values_shape, width_));
    } else {
      output_values_or =
          ctx->GetOutput(kValues, Shape({static_cast<int>(buffer.size())}));
    }
    if (!output_values_or.ok()) return output_values_or.status();
    auto& output_buffer =
        output_values_or.value()->template Data<tensorflow::tstring>();
    // Copy the joined n-grams into the output values buffer.
    int i = 0;
    for (const auto& v : buffer) output_buffer[i++] = v;
    return absl::OkStatus();
  }
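
As a rough standalone illustration of the windowing logic above (not part of the
kernel; RowNgrams, width, and separator are names invented here, and the only
assumed dependency is absl::StrJoin, which the kernel itself already uses), the
same per-row n-gram join on a plain std::vector<std::string> could look like:

  #include <cstddef>
  #include <iostream>
  #include <string>
  #include <vector>

  #include "absl/strings/str_join.h"

  // For one row of tokens, emit every contiguous window of `width` tokens
  // joined by `separator`, mirroring the inner loop of Invoke() above.
  std::vector<std::string> RowNgrams(const std::vector<std::string>& tokens,
                                     std::size_t width,
                                     const std::string& separator) {
    std::vector<std::string> ngrams;
    if (width == 0 || tokens.size() < width) return ngrams;  // No full window.
    for (std::size_t start = 0; start + width <= tokens.size(); ++start) {
      ngrams.push_back(absl::StrJoin(
          tokens.begin() + start, tokens.begin() + start + width, separator));
    }
    return ngrams;
  }

  int main() {
    // ["this", "is", "a", "test"] with width 2 -> "this is", "is a", "a test".
    for (const auto& ngram : RowNgrams({"this", "is", "a", "test"}, 2, " ")) {
      std::cout << ngram << "\n";
    }
    return 0;
  }

Unlike this sketch, the kernel keeps the n-grams of all rows in one flat buffer
and records per-row boundaries in output_row_splits, which is what lets the same
loop serve both the dense and the ragged output shapes.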