void Compute()

in tensorflow_text/core/kernels/sentencepiece_kernels.cc [271:377]


  void Compute(OpKernelContext* ctx) override {
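    // Look up the shared SentencepieceResource referenced by the handle in
    // input 0; ScopedUnref releases the reference when Compute returns.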
    SentencepieceResource* sp;
    const Tensor& resource_tensor = ctx->input(0);
    ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()());
    OP_REQUIRES_OK(
        ctx, ctx->resource_manager()->Lookup<SentencepieceResource>(
                 resource_handle.container(), resource_handle.name(), &sp));
    core::ScopedUnref unref_me(sp);

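    // Input 1 holds the strings to tokenize; flatten it so every element can
    // be encoded independently.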
    const Tensor& input_values_tensor = ctx->input(1);
    const auto input_values_flat =
        input_values_tensor.flat<tensorflow::tstring>();
    const int64 num_of_input_values = input_values_flat.size();

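    // nbest_size and alpha control sampling; each may be a scalar applied to
    // all inputs or a rank-1 tensor with one value per input string.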
    const Tensor* nbest_size_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("nbest_size", &nbest_size_tensor));
    const Tensor* alpha_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("alpha", &alpha_tensor));

    OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp));

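    // When all n-best candidates are returned, nbest_size must be a single
    // positive scalar shared by every input.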
    if (return_nbest_) {
      OP_REQUIRES(ctx, nbest_size_tensor->dims() == 0,
                  errors::InvalidArgument(
                      "When return_nbest is true nbest_size must "
                      "be a scalar; got",
                      nbest_size_tensor->shape().DebugString(), "instead"));
      OP_REQUIRES(ctx, nbest_size_tensor->scalar<int32>()() >= 1,
                  errors::InvalidArgument(
                      "When return_nbest is true nbest_size must be >= 1; got ",
                      nbest_size_tensor->scalar<int32>()()));
    }

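    // SentencePiece emits std::string pieces (or integer ids), so tstring
    // outputs are staged as std::string. `tokens` holds one row per input for
    // single-best encoding; `nbest_tokens` holds all candidates per input
    // when return_nbest_ is set and is flattened into `tokens` below.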
    std::vector<std::vector<typename std::conditional<
        std::is_same<T, tstring>::value, std::string, T>::type>>
        tokens(return_nbest_ ? 0 : num_of_input_values);
    std::vector<std::vector<std::vector<typename std::conditional<
        std::is_same<T, tstring>::value, std::string, T>::type>>>
        nbest_tokens(return_nbest_ ? num_of_input_values : 0);
    const bool return_nbest = return_nbest_;
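    // Encode the inputs in parallel: Shard splits [0, num_of_input_values)
    // across the CPU worker threads, charging kCostPerUnit per element.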
    const auto& worker_threads =
        *(ctx->device()->tensorflow_cpu_worker_threads());
    ::tensorflow::Shard(
        worker_threads.num_threads,  // max parallelism
        worker_threads.workers,      // thread pool
        num_of_input_values,         // number of input strings to encode
        kCostPerUnit,                // cost per unit
        [ctx, sp, &input_values_flat, &tokens, &nbest_tokens,
         &nbest_size_tensor, &alpha_tensor,
         return_nbest](int64 start, int64 limit) {
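          // Hold the resource's reader lock while calling into the shared
          // SentencePiece processor.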
          absl::ReaderMutexLock lock(&sp->mu);
          for (int i = start; i < limit; ++i) {
            const int32 nbest_size = nbest_size_tensor->dims() == 1
                                         ? nbest_size_tensor->vec<int32>()(i)
                                         : nbest_size_tensor->scalar<int32>()();
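            // return_nbest emits every candidate; an nbest_size of 0 or 1
            // selects deterministic encoding; any other value samples a
            // segmentation with the given alpha.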
            if (return_nbest) {
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.NBestEncode(
                                      input_values_flat(i), nbest_size,
                                      &nbest_tokens[i])));
            } else if (nbest_size == 0 || nbest_size == 1) {
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode(
                                      input_values_flat(i), &tokens[i])));
            } else {
              const float alpha = alpha_tensor->dims() == 1
                                      ? alpha_tensor->vec<float>()(i)
                                      : alpha_tensor->scalar<float>()();
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.SampleEncode(
                                      input_values_flat(i), nbest_size, alpha,
                                      &tokens[i])));
            }
          }
        });

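    // Flatten the n-best results so each candidate segmentation becomes its
    // own row of the ragged output.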
    if (return_nbest_) {
      for (auto& col : nbest_tokens) {
        for (auto& row : col) {
          tokens.push_back(std::move(row));
        }
      }
      nbest_tokens.clear();
    }
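    // Count the total number of tokens to size the flat values output.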
    int64 total_tokens = 0;
    for (auto& tokens_row : tokens) {
      total_tokens += tokens_row.size();
    }

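    // The result is a ragged tensor encoded as a flat values tensor plus a
    // row-splits tensor with one more entry than there are rows.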
    Tensor* output_values_tensor = nullptr;
    Tensor* output_splits_tensor = nullptr;

    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, {total_tokens}, &output_values_tensor));
    int64 splits_size = tokens.size() + 1;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(1, {splits_size}, &output_splits_tensor));

    auto values_tensor_flat = output_values_tensor->vec<T>();
    auto splits_tensor_flat = output_splits_tensor->vec<Tsplits>();

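    // Copy tokens into the values tensor, recording the running offset after
    // each row as its split boundary.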
    int i = 0;
    splits_tensor_flat(0) = 0;
    for (int row = 0; row < tokens.size(); ++row) {
      for (int col = 0; col < tokens[row].size(); ++col, ++i) {
        values_tensor_flat(i) = tokens[row][col];
      }
      splits_tensor_flat(row + 1) = i;
    }
  }