in tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc [197:281]
// Validates the model's input/output tensors against `options`, creates the
// regex preprocessor, and (if metadata provides one) loads the label file
// associated with the output score tensor.
//
// Returns kInvalidArgument (with a TfLiteSupportStatus payload) when the
// requested input tensor, output score tensor, or tensor types do not match
// expectations; OkStatus otherwise.
absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
  struct_options_ = options;
  // Resolve the input tensor either by name or by index, as configured.
  int input_index = FindTensorIndex(
      GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
      options.input_tensor_name, options.input_tensor_index);
  if (input_index < 0 || input_index >= GetInputCount()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No input tensor found with name ",
                     options.input_tensor_name, " or at index ",
                     options.input_tensor_index),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }
  // Create preprocessor.
  ASSIGN_OR_RETURN(preprocessor_, processor::RegexPreprocessor::Create(
                                      GetTfLiteEngine(), input_index));
  // output score tensor should be type
  // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
  std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
  const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();
  const auto scores = FindTensorWithNameOrIndex(
      output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
      options.output_score_tensor_index);
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No output score tensor found with name ",
                     options.output_score_tensor_name, " or at index ",
                     options.output_score_tensor_index),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteInt8,
                                               kTfLiteInt16, kTfLiteFloat32,
                                               kTfLiteFloat64, kTfLiteBool};
  if (!absl::c_linear_search(valid_types, scores->type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("Type mismatch for score tensor ", scores->name,
                     ". Requested one of these types: "
                     "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
                     TfLiteTypeGetName(scores->type), "."),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  // Extract associated label file from output score tensor if one exists, a
  // well-formatted metadata should have same number of tensors with the model.
  if (output_tensor_metadatas &&
      output_tensor_metadatas->size() == output_tensors.size()) {
    for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
      const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
      // Match the score tensor's metadata entry by name or by index, then try
      // to populate labels from its associated label file.
      if ((metadata->name() && metadata->name()->string_view() ==
                                   options.output_score_tensor_name) ||
          i == options.output_score_tensor_index) {
        if (TrySetLabelFromMetadata(metadata).ok()) {
          return absl::OkStatus();
        }
      }
    }
  }
  // If labels_vector_ is not set up from metadata, try register output label
  // tensor from options.
  if (labels_vector_ == nullptr) {
    // output label tensor should be type STRING or INT32 if the one exists
    auto labels = FindTensorWithNameOrIndex(
        output_tensors, output_tensor_metadatas,
        options.output_label_tensor_name, options.output_label_tensor_index);
    if (labels != nullptr && labels->type != kTfLiteString &&
        labels->type != kTfLiteInt32) {
      // BUGFIX: report the label tensor's name/type, not the score tensor's;
      // the previous message printed `scores` here, which was misleading.
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for label tensor ", labels->name,
                       ". Requested STRING or INT32, got ",
                       TfLiteTypeGetName(labels->type), "."),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }
  return absl::OkStatus();
}