in tensorflow_text/core/kernels/sentencepiece_kernels.cc [412:526]
// Encodes every string of input 1 with the SentencePiece processor looked up
// through the resource handle in input 0, and writes four outputs:
//   0: flat values (piece or id, via GetPieceOrId<T>) for all results,
//   1: row splits into the flat values, one row per result,
//   2: begin() of every piece,
//   3: end() of every piece.
//
// The "nbest_size" and "alpha" inputs select the encoding mode per input:
//   * return_nbest_ == true: nbest_size must be a scalar >= 1 and NBestEncode
//     is used; each n-best candidate becomes its own output row.
//   * otherwise nbest_size of 0 or 1 uses Encode, any other value uses
//     SampleEncode with the matching alpha. Both inputs may be a scalar
//     (shared by all inputs) or a vector indexed by input position.
//
// Any failure is reported on `ctx` and no outputs are produced.
void Compute(OpKernelContext* ctx) override {
  SentencepieceResource* sp;
  const Tensor& resource_tensor = ctx->input(0);
  ResourceHandle resource_handle(resource_tensor.scalar<ResourceHandle>()());
  OP_REQUIRES_OK(
      ctx, ctx->resource_manager()->Lookup<SentencepieceResource>(
               resource_handle.container(), resource_handle.name(), &sp));
  // Releases the reference on `sp` when this scope exits.
  core::ScopedUnref unref_me(sp);

  const Tensor& input_values_tensor = ctx->input(1);
  const auto input_values_flat =
      input_values_tensor.flat<tensorflow::tstring>();
  const int64 num_of_input_values = input_values_flat.size();

  const Tensor* nbest_size_tensor = nullptr;
  OP_REQUIRES_OK(ctx, ctx->input("nbest_size", &nbest_size_tensor));
  const Tensor* alpha_tensor = nullptr;
  OP_REQUIRES_OK(ctx, ctx->input("alpha", &alpha_tensor));
  OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp));

  if (return_nbest_) {
    OP_REQUIRES(ctx, nbest_size_tensor->dims() == 0,
                errors::InvalidArgument(
                    "When return_nbest is true nbest_size must "
                    "be a scalar; got ",
                    nbest_size_tensor->shape().DebugString(), " instead"));
    OP_REQUIRES(ctx, nbest_size_tensor->scalar<int32>()() >= 1,
                errors::InvalidArgument(
                    "When return_nbest is true nbest_size must be >= 1; got ",
                    nbest_size_tensor->scalar<int32>()()));
  } else {
    // The shard callback indexes vector-shaped nbest_size/alpha by input
    // position, so reject shapes that would be read out of bounds or that
    // are neither scalar nor vector before any work is scheduled.
    OP_REQUIRES(
        ctx,
        nbest_size_tensor->dims() == 0 ||
            (nbest_size_tensor->dims() == 1 &&
             nbest_size_tensor->NumElements() >= num_of_input_values),
        errors::InvalidArgument(
            "nbest_size must be a scalar or a vector with at least one "
            "element per input value (",
            num_of_input_values, "); got shape ",
            nbest_size_tensor->shape().DebugString()));
    OP_REQUIRES(ctx,
                alpha_tensor->dims() == 0 ||
                    (alpha_tensor->dims() == 1 &&
                     alpha_tensor->NumElements() >= num_of_input_values),
                errors::InvalidArgument(
                    "alpha must be a scalar or a vector with at least one "
                    "element per input value (",
                    num_of_input_values, "); got shape ",
                    alpha_tensor->shape().DebugString()));
  }

  // Exactly one of these holds a slot per input; the other starts empty.
  std::vector<sentencepiece::SentencePieceText> results(
      return_nbest_ ? 0 : num_of_input_values);
  std::vector<sentencepiece::NBestSentencePieceText> nbest_results(
      return_nbest_ ? num_of_input_values : 0);

  const bool return_nbest = return_nbest_;
  const bool nbest_size_is_vector = nbest_size_tensor->dims() == 1;
  const bool alpha_is_vector = alpha_tensor->dims() == 1;
  const auto& worker_threads =
      *(ctx->device()->tensorflow_cpu_worker_threads());
  ::tensorflow::Shard(
      worker_threads.num_threads,  // max parallelism
      worker_threads.workers,      // thread pool
      num_of_input_values,         // total number of data to process.
      kCostPerUnit,
      [ctx, sp, &input_values_flat, &results, &nbest_results,
       nbest_size_tensor, alpha_tensor, return_nbest, nbest_size_is_vector,
       alpha_is_vector](int64 start, int64 limit) {
        // Hold the resource's reader lock while its processor is in use.
        absl::ReaderMutexLock lock(&sp->mu);
        // 64-bit index: the shard bounds are int64 and must not be truncated.
        for (int64 i = start; i < limit; ++i) {
          const int32 nbest_size = nbest_size_is_vector
                                       ? nbest_size_tensor->vec<int32>()(i)
                                       : nbest_size_tensor->scalar<int32>()();
          if (return_nbest) {
            OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.NBestEncode(
                                    input_values_flat(i), nbest_size,
                                    &nbest_results[i])));
          } else if (nbest_size == 0 || nbest_size == 1) {
            OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode(
                                    input_values_flat(i), &results[i])));
          } else {
            const float alpha = alpha_is_vector
                                    ? alpha_tensor->vec<float>()(i)
                                    : alpha_tensor->scalar<float>()();
            OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.SampleEncode(
                                    input_values_flat(i), nbest_size, alpha,
                                    &results[i])));
          }
        }
      });
  // OP_REQUIRES_OK inside the callback only returns from the callback, so an
  // encode failure has to be picked up here; otherwise outputs would be built
  // from partially filled results.
  if (!ctx->status().ok()) {
    return;
  }

  if (return_nbest_) {
    // Flatten the n-best lists: every candidate becomes one output row.
    for (auto& nbest : nbest_results) {
      // NOTE(review): if nbests() returns a const reference, this std::move
      // degrades to a copy -- confirm against the sentencepiece API in use.
      for (auto& result : nbest.nbests()) {
        results.push_back(std::move(result));
      }
    }
  }

  int64 total_tokens = 0;
  for (const auto& sp_result : results) {
    total_tokens += sp_result.pieces_size();
  }

  Tensor* output_values_tensor = nullptr;
  Tensor* output_splits_tensor = nullptr;
  Tensor* output_starts_tensor = nullptr;
  Tensor* output_limits_tensor = nullptr;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(0, {total_tokens}, &output_values_tensor));
  const int64 splits_size = static_cast<int64>(results.size()) + 1;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, {splits_size}, &output_splits_tensor));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(2, {total_tokens}, &output_starts_tensor));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(3, {total_tokens}, &output_limits_tensor));

  auto values_tensor_flat = output_values_tensor->vec<T>();
  auto splits_tensor_flat = output_splits_tensor->vec<Tsplits>();
  auto starts_tensor_flat = output_starts_tensor->vec<int64>();
  auto limits_tensor_flat = output_limits_tensor->vec<int64>();

  // `token_index` walks the flat outputs; it is 64-bit to match total_tokens.
  int64 token_index = 0;
  splits_tensor_flat(0) = 0;
  for (size_t row = 0; row < results.size(); ++row) {
    for (const auto& piece : results[row].pieces()) {
      values_tensor_flat(token_index) = GetPieceOrId<T>(piece);
      starts_tensor_flat(token_index) = piece.begin();
      limits_tensor_flat(token_index) = piece.end();
      ++token_index;
    }
    // Row `row` ends where the next one starts.
    splits_tensor_flat(row + 1) = static_cast<Tsplits>(token_index);
  }
}