in tensorflow_text/core/kernels/ngrams_kernel_template.h [124:215]
absl::Status Invoke(InvokeContext* ctx) {
  // Computes width_-grams (joined with string_separator_) over the innermost
  // dimension of the string input, writing results to the output tensor(s).
  // Two layouts are supported:
  //   * Tensor output (NumOutputs() == 1): the input is a dense tensor; each
  //     innermost row is treated as one sequence, and synthetic row_splits
  //     are fabricated locally so both layouts share one ngram loop.
  //   * RaggedTensor output: row_splits tensors arrive starting at input
  //     index kRowSplitsStart; all but the innermost are copied through
  //     unchanged, and the innermost pair drives the ngram loop.
  // Returns OkStatus on success, or the first error produced by tensor
  // access / allocation (propagated via SH_ASSIGN_OR_RETURN).
  using Tsplits = int64_t;
  // Storage for the dummy input and output row_splits used in the tensor
  // case. Declared at function scope so the raw pointers below remain valid
  // for the whole ngram loop.
  std::vector<Tsplits> tensor_input_row_splits;
  std::vector<Tsplits> tensor_output_row_splits;
  const Tsplits* input_row_splits;
  Tsplits* output_row_splits;
  int n_row_splits = 0;
  SH_ASSIGN_OR_RETURN(const auto input_values, ctx->GetInput(kValues));
  const Shape input_values_shape(input_values->Shape());
  // Tensor output
  if (ctx->NumOutputs() == 1) {
    // Generate mock input and output innermost row_splits.
    int64_t total_tokens =
        input_values->template Data<tensorflow::tstring>().size();
    int64_t tokens_per_element =
        input_values_shape->at(input_values_shape->size() - 1);
    // Guard the division below: a non-positive innermost dimension would be
    // a division by zero (undefined behavior). NOTE(review): upstream shape
    // validation may already rule this out — confirm before removing.
    if (tokens_per_element <= 0) {
      return absl::InvalidArgumentError(
          "NGrams: innermost input dimension must be positive.");
    }
    tensor_output_row_splits.resize(total_tokens / tokens_per_element + 1);
    for (int64_t i = 0; i <= total_tokens; i += tokens_per_element) {
      tensor_input_row_splits.push_back(i);
    }
    input_row_splits = tensor_input_row_splits.data();
    output_row_splits = tensor_output_row_splits.data();
    n_row_splits = tensor_input_row_splits.size();
  } else {
    // RaggedTensor output: pass all outer row_splits through unchanged;
    // only the innermost pair (handled after the loop) feeds the ngram loop.
    int index = 0;
    const int num_row_splits = ctx->NumInputs() - kRowSplitsStart;
    while (index < num_row_splits - 1) {
      SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits,
                          ctx->GetInput(kRowSplitsStart + index));
      SH_ASSIGN_OR_RETURN(
          const auto output_tensor_row_splits,
          ctx->GetOutput(kRowSplitsStart + index,
                         Shape(input_tensor_row_splits->Shape())));
      const auto input_buffer =
          input_tensor_row_splits->template Data<Tsplits>();
      const auto output_buffer =
          output_tensor_row_splits->template Data<Tsplits>();
      std::memcpy(output_buffer.data(), input_buffer.data(),
                  input_buffer.size() * sizeof(Tsplits));
      ++index;
    }
    // Innermost row_splits: read directly from / write directly into the
    // tensors. The underlying buffers are owned by the context, so these
    // pointers stay valid after the TensorView locals go out of scope.
    SH_ASSIGN_OR_RETURN(const auto input_tensor_row_splits,
                        ctx->GetInput(kRowSplitsStart + index));
    SH_ASSIGN_OR_RETURN(
        const auto output_tensor_row_splits,
        ctx->GetOutput(kRowSplitsStart + index,
                       Shape(input_tensor_row_splits->Shape())));
    input_row_splits =
        input_tensor_row_splits->template Data<Tsplits>().data();
    output_row_splits =
        output_tensor_row_splits->template Data<Tsplits>().data();
    n_row_splits = input_tensor_row_splits->Shape().at(0);
  }
  const auto input_values_data =
      input_values->template Data<tensorflow::tstring>();
  // Accumulates all ngram strings, in order; row i spans
  // [output_row_splits[i], output_row_splits[i + 1]).
  std::vector<std::string> buffer;
  for (int i = 0; i < n_row_splits - 1; ++i) {
    output_row_splits[i] = buffer.size();
    std::vector<tensorflow::tstring> tokens;
    // Use the row-splits' own 64-bit type for the token index: a plain
    // `int` would overflow on inputs with more than INT_MAX tokens.
    for (Tsplits j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
      tokens.emplace_back(input_values_data.at(j));
      // Wait until a full window of width_ tokens has been collected.
      if (tokens.size() < width_) continue;
      // Keep only the last width_ tokens (sliding window), then join them
      // into one ngram.
      tokens.erase(tokens.begin(), tokens.begin() + tokens.size() - width_);
      buffer.push_back(absl::StrJoin(tokens, string_separator_));
    }
  }
  output_row_splits[n_row_splits - 1] = buffer.size();
  // Allocate the output values tensor now that the ngram count is known.
  tflite::shim::TensorViewOr output_values_or;
  if (ctx->NumOutputs() == 1) {
    output_values_or = ctx->GetOutput(
        kValues, OutputValuesTensorShape(input_values_shape, width_));
  } else {
    output_values_or =
        ctx->GetOutput(kValues, Shape({static_cast<int>(buffer.size())}));
  }
  if (!output_values_or.ok()) return output_values_or.status();
  auto& output_buffer =
      output_values_or.value()->template Data<tensorflow::tstring>();
  int i = 0;
  for (const auto& v : buffer) output_buffer[i++] = v;
  return absl::OkStatus();
}