StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs()

in tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc [137:250]


StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
    const TfLiteEngine::Interpreter& interpreter,
    const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
  ASSIGN_OR_RETURN(const TensorMetadata* metadata,
                   GetInputTensorMetadataIfAny(metadata_extractor));

  // Metadata is optional: image properties and normalization options are only
  // populated when the model actually provides them.
  const ImageProperties* props = nullptr;
  absl::optional<NormalizationOptions> normalization_options;
  if (metadata != nullptr) {
    ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
    ASSIGN_OR_RETURN(normalization_options,
                     GetNormalizationOptionsIfAny(*metadata));
  }

  if (TfLiteEngine::InputCount(&interpreter) != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Models are assumed to have a single input.",
        TfLiteSupportStatus::kInvalidNumInputTensorsError);
  }

  // Input-related specifications.
  const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
  if (input_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Only 4D tensors in BHWD layout are supported.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
  TfLiteType input_type = input_tensor->type;
  if (!absl::c_linear_search(valid_types, input_type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat(
            "Type mismatch for input tensor ", input_tensor->name,
            ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
            TfLiteTypeGetName(input_type), "."),
        TfLiteSupportStatus::kInvalidInputTensorTypeError);
  }

  // The expected layout is BHWD, i.e. batch x height x width x color
  // See https://www.tensorflow.org/guide/tensors
  const int batch = input_tensor->dims->data[0];
  const int height = input_tensor->dims->data[1];
  const int width = input_tensor->dims->data[2];
  const int depth = input_tensor->dims->data[3];

  if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Only RGB color space is supported for now.",
                                   TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (batch != 1 || depth != 3) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("The input tensor should have dimensions 1 x height x "
                     "width x 3. Got ",
                     batch, " x ", height, " x ", width, " x ", depth, "."),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
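  // Total tensor byte size and per-element byte size, used in the consistency
  // checks below.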
  int bytes_size = input_tensor->bytes;
  size_t byte_depth =
      input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);

  // Sanity checks: float inputs require normalization metadata, and the
  // reported dimensions must be consistent with the tensor's byte size.
  if (input_type == kTfLiteFloat32) {
    if (!normalization_options.has_value()) {
      return CreateStatusWithPayload(
          absl::StatusCode::kNotFound,
          "Input tensor has type kTfLiteFloat32: it requires specifying "
          "NormalizationOptions metadata to preprocess input images.",
          TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
    } else if (bytes_size / sizeof(float) %
                   normalization_options.value().num_values !=
               0) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          "The number of elements in the input tensor must be a multiple of "
          "the number of normalization parameters.",
          TfLiteSupportStatus::kInvalidArgumentError);
    }
  }
  if (width <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input width should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (height <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input height should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (bytes_size != height * width * depth * byte_depth) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "The input size in bytes does not correspond to the expected number of "
        "pixels.",
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  // Note: in the future, additional checks against `props->default_size()`
  // might be added. Also, verify that NormalizationOptions, if any, do specify
  // a single value when color space is grayscale.

  ImageTensorSpecs result;
  result.image_width = width;
  result.image_height = height;
  result.color_space = ColorSpaceType_RGB;
  result.tensor_type = input_type;
  result.normalization_options = normalization_options;

  return result;
}
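
A minimal usage sketch (illustrative only, not part of the file above): it assumes a TfLiteEngine* engine that has already loaded a model together with its metadata, and that the engine exposes interpreter() and metadata_extractor() accessors, similar to how the vision task APIs in this library invoke the function.

StatusOr<ImageTensorSpecs> GetInputSpecs(TfLiteEngine* engine) {
  // Hypothetical helper; `engine` is assumed to be fully initialized.
  ASSIGN_OR_RETURN(ImageTensorSpecs specs,
                   BuildInputImageTensorSpecs(*engine->interpreter(),
                                              *engine->metadata_extractor()));
  // specs.image_width / specs.image_height give the expected input resolution;
  // specs.normalization_options, if set, hold the mean/std values to apply to
  // kTfLiteFloat32 inputs.
  return specs;
}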