void Compute()

in tensorflow_text/core/kernels/sentencepiece_kernels.cc [412:526]


  // Tokenizes a batch of input strings with the SentencePiece model held by
  // the resource handle in input 0. Emits a ragged result encoded as four
  // dense tensors: values (pieces or ids, per GetPieceOrId<T>), row splits,
  // and per-token [begin, end) offsets into the input string.
  void Compute(OpKernelContext* ctx) override {
    // Look up the shared SentencePiece resource; unref when we go out of scope.
    SentencepieceResource* sp;
    const Tensor& resource_tensor = ctx->input(0);
    ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()());
    OP_REQUIRES_OK(
        ctx, ctx->resource_manager()->Lookup<SentencepieceResource>(
                 resource_handle.container(), resource_handle.name(), &sp));
    core::ScopedUnref unref_me(sp);

    const Tensor& input_values_tensor = ctx->input(1);
    const auto input_values_flat =
        input_values_tensor.flat<tensorflow::tstring>();
    const int64 num_of_input_values = input_values_flat.size();

    // nbest_size / alpha may each be a scalar (shared across the batch) or a
    // rank-1 tensor (one value per input row); the shard loop handles both.
    const Tensor* nbest_size_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("nbest_size", &nbest_size_tensor));
    const Tensor* alpha_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->input("alpha", &alpha_tensor));

    OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp));

    if (return_nbest_) {
      // NBestEncode requires a single nbest_size for the whole batch.
      OP_REQUIRES(ctx, nbest_size_tensor->dims() == 0,
                  errors::InvalidArgument(
                      "When return_nbest is true nbest_size must "
                      "be a scalar; got ",
                      nbest_size_tensor->shape().DebugString(), " instead"));
      OP_REQUIRES(ctx, nbest_size_tensor->scalar<int32>()() >= 1,
                  errors::InvalidArgument(
                      "When return_nbest is true nbest_size must be >= 1; got ",
                      nbest_size_tensor->scalar<int32>()()));
    }

    // Exactly one of these vectors is populated, depending on return_nbest_.
    std::vector<sentencepiece::SentencePieceText> results(
        return_nbest_ ? 0 : num_of_input_values);
    std::vector<sentencepiece::NBestSentencePieceText> nbest_results(
        return_nbest_ ? num_of_input_values : 0);
    const bool return_nbest = return_nbest_;
    const auto& worker_threads =
        *(ctx->device()->tensorflow_cpu_worker_threads());
    ::tensorflow::Shard(
        worker_threads.num_threads,  // max parallelism
        worker_threads.workers,      // thread pool
        num_of_input_values,         // total number of data to process.
        kCostPerUnit,
        [ctx, sp, &input_values_flat, &results, &nbest_results,
         &nbest_size_tensor, &alpha_tensor,
         return_nbest](int64 start, int64 limit) {
          // NOTE(review): reader lock presumably guards against concurrent
          // reconfiguration of sp->processor elsewhere; encoding itself only
          // reads the processor.
          absl::ReaderMutexLock lock(&sp->mu);
          // int64 index: start/limit are int64; an int would narrow.
          for (int64 i = start; i < limit; ++i) {
            const int32 nbest_size = nbest_size_tensor->dims() == 1
                                         ? nbest_size_tensor->vec<int32>()(i)
                                         : nbest_size_tensor->scalar<int32>()();
            if (return_nbest) {
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.NBestEncode(
                                      input_values_flat(i), nbest_size,
                                      &nbest_results[i])));
            } else if (nbest_size == 0 || nbest_size == 1) {
              // Deterministic (single best) encoding.
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode(
                                      input_values_flat(i), &results[i])));
            } else {
              const float alpha = alpha_tensor->dims() == 1
                                      ? alpha_tensor->vec<float>()(i)
                                      : alpha_tensor->scalar<float>()();
              OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.SampleEncode(
                                      input_values_flat(i), nbest_size, alpha,
                                      &results[i])));
            }
          }
        });
    // Shard joins all workers before returning; if any worker recorded an
    // error on the context, bail out instead of emitting partial results.
    if (!ctx->status().ok()) return;

    if (return_nbest_) {
      // Flatten the per-input n-best lists into `results`, one row per
      // candidate. Reserve up front to avoid repeated reallocation.
      int64 num_candidates = 0;
      for (const auto& nbest : nbest_results) {
        num_candidates += nbest.nbests_size();
      }
      results.reserve(results.size() + num_candidates);
      for (auto& nbest : nbest_results) {
        // Iterate mutable_nbests() so std::move below actually moves; moving
        // from the const nbests() accessor would silently copy each proto.
        for (auto& result : *nbest.mutable_nbests()) {
          results.push_back(std::move(result));
        }
      }
    }
    int64 total_tokens = 0;
    for (const auto& sp_result : results) {
      total_tokens += sp_result.pieces_size();
    }

    Tensor* output_values_tensor = nullptr;
    Tensor* output_splits_tensor = nullptr;
    Tensor* output_starts_tensor = nullptr;
    Tensor* output_limits_tensor = nullptr;

    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, {total_tokens}, &output_values_tensor));
    const int64 splits_size = results.size() + 1;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(1, {splits_size}, &output_splits_tensor));
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(2, {total_tokens}, &output_starts_tensor));
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(3, {total_tokens}, &output_limits_tensor));

    auto values_tensor_flat = output_values_tensor->vec<T>();
    auto splits_tensor_flat = output_splits_tensor->vec<Tsplits>();
    auto starts_tensor_flat = output_starts_tensor->vec<int64>();
    auto limits_tensor_flat = output_limits_tensor->vec<int64>();

    // Fill flat outputs. int64 index: total_tokens can exceed int32 range.
    // (`piece`, not `sp`, to avoid shadowing the resource pointer above.)
    int64 i = 0;
    splits_tensor_flat(0) = 0;
    for (int64 row = 0; row < static_cast<int64>(results.size()); ++row) {
      for (const auto& piece : results[row].pieces()) {
        values_tensor_flat(i) = GetPieceOrId<T>(piece);
        starts_tensor_flat(i) = piece.begin();  // byte offsets into the input
        limits_tensor_flat(i) = piece.end();
        ++i;
      }
      splits_tensor_flat(row + 1) = i;
    }
  }