in tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc [137:250]
// Builds the ImageTensorSpecs describing the (single) input tensor of
// `interpreter`, using the optional input TensorMetadata available through
// `metadata_extractor` (image properties and normalization options).
//
// Returns an error status (with a TfLiteSupportStatus payload) when:
//   - the model does not have exactly one input tensor (kInvalidArgument);
//   - the input tensor is not 4D (kInvalidArgument);
//   - the input type is neither kTfLiteUInt8 nor kTfLiteFloat32
//     (kInvalidArgument);
//   - metadata specifies a color space other than RGB (kInvalidArgument);
//   - the dimensions are not 1 x height x width x 3 (kInvalidArgument);
//   - the input is float but no NormalizationOptions are provided (kNotFound),
//     or the element count is not a multiple of the number of normalization
//     values (kInvalidArgument);
//   - width or height is not positive (kInvalidArgument);
//   - the tensor byte size does not match height * width * depth * element
//     size (kInvalidArgument).
// Errors from reading the metadata itself are propagated as-is.
StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
    const TfLiteEngine::Interpreter& interpreter,
    const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
  ASSIGN_OR_RETURN(const TensorMetadata* metadata,
                   GetInputTensorMetadataIfAny(metadata_extractor));
  const ImageProperties* props = nullptr;
  absl::optional<NormalizationOptions> normalization_options;
  if (metadata != nullptr) {
    ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
    ASSIGN_OR_RETURN(normalization_options,
                     GetNormalizationOptionsIfAny(*metadata));
  }
  if (TfLiteEngine::InputCount(&interpreter) != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Models are assumed to have a single input.",
        TfLiteSupportStatus::kInvalidNumInputTensorsError);
  }
  // Input-related specifications.
  const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
  if (input_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Only 4D tensors in BHWD layout are supported.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
  const TfLiteType input_type = input_tensor->type;
  if (!absl::c_linear_search(valid_types, input_type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat(
            "Type mismatch for input tensor ", input_tensor->name,
            ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
            TfLiteTypeGetName(input_type), "."),
        TfLiteSupportStatus::kInvalidInputTensorTypeError);
  }
  // The expected layout is BHWD, i.e. batch x height x width x color
  // See https://www.tensorflow.org/guide/tensors
  const int batch = input_tensor->dims->data[0];
  const int height = input_tensor->dims->data[1];
  const int width = input_tensor->dims->data[2];
  const int depth = input_tensor->dims->data[3];
  if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
    return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                   "Only RGB color space is supported for now.",
                                   TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (batch != 1 || depth != 3) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("The input tensor should have dimensions 1 x height x "
                     "width x 3. Got ",
                     batch, " x ", height, " x ", width, " x ", depth, "."),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  // Keep the byte count in its native unsigned type: narrowing it to `int`
  // would truncate (or turn negative) for very large tensors.
  const size_t bytes_size = input_tensor->bytes;
  const size_t byte_depth =
      input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);
  // Sanity checks.
  if (input_type == kTfLiteFloat32) {
    if (!normalization_options.has_value()) {
      return CreateStatusWithPayload(
          absl::StatusCode::kNotFound,
          "Input tensor has type kTfLiteFloat32: it requires specifying "
          "NormalizationOptions metadata to preprocess input images.",
          TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
    } else if ((bytes_size / sizeof(float)) %
                   normalization_options->num_values !=
               0) {
      // NOTE(review): assumes num_values is validated as > 0 by
      // GetNormalizationOptionsIfAny; a zero value would divide by zero here.
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          "The number of elements in the input tensor must be a multiple of "
          "the number of normalization parameters.",
          TfLiteSupportStatus::kInvalidArgumentError);
    }
  }
  if (width <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input width should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (height <= 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument, "The input height should be positive.",
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  // Compute the expected size entirely in size_t: multiplying the int
  // dimensions first could overflow signed int (UB) for large images. The
  // casts are safe because height/width were just checked to be positive and
  // depth is known to be 3.
  const size_t expected_bytes = static_cast<size_t>(height) *
                                static_cast<size_t>(width) *
                                static_cast<size_t>(depth) * byte_depth;
  if (bytes_size != expected_bytes) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "The input size in bytes does not correspond to the expected number of "
        "pixels.",
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }
  // Note: in the future, additional checks against `props->default_size()`
  // might be added. Also, verify that NormalizationOptions, if any, do specify
  // a single value when color space is grayscale.
  ImageTensorSpecs result;
  result.image_width = width;
  result.image_height = height;
  result.color_space = ColorSpaceType_RGB;
  result.tensor_type = input_type;
  result.normalization_options = normalization_options;
  return result;
}