absl::Status NLClassifier::Initialize()

in tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc [197:281]


// Initializes the classifier from `options`.
//
// Steps, in order:
//   1. Stores a copy of `options` in `struct_options_`.
//   2. Resolves the input tensor from the configured name / index and creates
//      the regex preprocessor bound to it.
//   3. Resolves the output score tensor and validates its type.
//   4. Tries to load labels from the metadata entry associated with the score
//      tensor; on success initialization is complete.
//   5. Otherwise, if no label vector is set, validates the type of the
//      optional output label tensor.
//
// Returns:
//   - OK on success.
//   - kInvalidArgument (payload kInputTensorNotFoundError) if the input tensor
//     cannot be resolved to a valid index.
//   - kInvalidArgument (payload kOutputTensorNotFoundError) if the score
//     tensor cannot be found.
//   - kInvalidArgument (payload kInvalidOutputTensorTypeError) if the score or
//     label tensor has an unsupported type.
//   - Any error propagated from RegexPreprocessor::Create.
absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
  struct_options_ = options;

  const int input_index = FindTensorIndex(
      GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
      options.input_tensor_name, options.input_tensor_index);

  // Reject anything outside [0, GetInputCount()) before it is handed to the
  // preprocessor.
  if (input_index < 0 || input_index >= GetInputCount()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No input tensor found with name ",
                     options.input_tensor_name, " or at index ",
                     options.input_tensor_index),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  // Create preprocessor.
  ASSIGN_OR_RETURN(preprocessor_, processor::RegexPreprocessor::Create(
                                      GetTfLiteEngine(), input_index));

  // output score tensor should be type
  // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
  std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
  const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  const auto scores = FindTensorWithNameOrIndex(
      output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
      options.output_score_tensor_index);
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No output score tensor found with name ",
                     options.output_score_tensor_name, " or at index ",
                     options.output_score_tensor_index),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8,   kTfLiteInt8,
                                               kTfLiteInt16,   kTfLiteFloat32,
                                               kTfLiteFloat64, kTfLiteBool};
  if (!absl::c_linear_search(valid_types, scores->type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("Type mismatch for score tensor ", scores->name,
                     ". Requested one of these types: "
                     "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
                     TfLiteTypeGetName(scores->type), "."),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  // Extract associated label file from output score tensor if one exists, a
  // well-formatted metadata should have same number of tensors with the model.
  if (output_tensor_metadatas &&
      output_tensor_metadatas->size() == output_tensors.size()) {
    // Hoisted as a signed value so the index comparison below is not a
    // signed/unsigned mix.
    const int metadata_count =
        static_cast<int>(output_tensor_metadatas->size());
    for (int i = 0; i < metadata_count; ++i) {
      const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
      const bool matches_by_name =
          metadata->name() &&
          metadata->name()->string_view() == options.output_score_tensor_name;
      if (matches_by_name || i == options.output_score_tensor_index) {
        // A failure here is intentionally not propagated: the loop keeps
        // looking and, if nothing succeeds, falls through to the label tensor
        // check below.
        if (TrySetLabelFromMetadata(metadata).ok()) {
          return absl::OkStatus();
        }
      }
    }
  }

  // If labels_vector_ is not set up from metadata, try register output label
  // tensor from options.
  if (labels_vector_ == nullptr) {
    // output label tensor should be type STRING or INT32 if the one exists
    const auto labels = FindTensorWithNameOrIndex(
        output_tensors, output_tensor_metadatas,
        options.output_label_tensor_name, options.output_label_tensor_index);
    if (labels != nullptr && labels->type != kTfLiteString &&
        labels->type != kTfLiteInt32) {
      // Report the label tensor's own name and type (not the score tensor's),
      // so the message identifies the tensor that actually failed validation.
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for label tensor ", labels->name,
                       ". Requested STRING or INT32, got ",
                       TfLiteTypeGetName(labels->type), "."),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }
  return absl::OkStatus();
}