in tensorflow_lite_support/metadata/python/metadata_writers/object_detector.py [0:0]
def create_from_metadata_info(
    cls,
    model_buffer: bytearray,
    general_md: Optional[metadata_info.GeneralMd] = None,
    input_md: Optional[metadata_info.InputImageTensorMd] = None,
    output_location_md: Optional[metadata_info.TensorMd] = None,
    output_category_md: Optional[metadata_info.CategoryTensorMd] = None,
    output_score_md: Union[None, metadata_info.TensorMd,
                           metadata_info.ClassificationTensorMd] = None,
    output_number_md: Optional[metadata_info.TensorMd] = None):
    """Creates MetadataWriter based on general/input/outputs information.

    Any metadata argument left as None is replaced by a default built from the
    module-level name/description constants. When a caller-supplied output
    metadata uses a name other than the default one, a warning is logged,
    because the TFLite Task Library uses the tensor names to handle the output
    order.

    Args:
        model_buffer: valid buffer of the model file.
        general_md: general information about the model.
        input_md: input image tensor information.
        output_location_md: output location tensor information. The location
            tensor is a multidimensional array of [N][4] floating point values
            between 0 and 1, the inner arrays representing bounding boxes in the
            form [top, left, bottom, right].
        output_category_md: output category tensor information. The category
            tensor is an array of N integers (output as floating point values)
            each indicating the index of a class label from the labels file.
        output_score_md: output score tensor information. The score tensor is
            an array of N floating point values between 0 and 1 representing
            probability that a class was detected. Use ClassificationTensorMd
            to calibrate score.
        output_number_md: output number of detections tensor information. This
            tensor is an integer value of N.

    Returns:
        A MetadataWriter object.

    Raises:
        ValueError: if the number of output tensors in the model does not match
            the four expected outputs (location, category, score and number of
            detections).
    """
    if general_md is None:
        general_md = metadata_info.GeneralMd(
            name=_MODEL_NAME, description=_MODEL_DESCRIPTION)

    if input_md is None:
        input_md = metadata_info.InputImageTensorMd(
            name=_INPUT_NAME,
            description=_INPUT_DESCRIPTION,
            color_space_type=_metadata_fb.ColorSpaceType.RGB)

    warn_message_format = (
        "The output name isn't the default string \"%s\". This may cause the "
        "model not work in the TFLite Task Library since the tensor name will "
        "be used to handle the output order in the TFLite Task Library.")

    if output_location_md is None:
        output_location_md = metadata_info.TensorMd(
            name=_OUTPUT_LOCATION_NAME,
            description=_OUTPUT_LOCATION_DESCRIPTION)
    elif output_location_md.name != _OUTPUT_LOCATION_NAME:
        logging.warning(warn_message_format, _OUTPUT_LOCATION_NAME)

    if output_category_md is None:
        output_category_md = metadata_info.CategoryTensorMd(
            name=_OUTPUT_CATRGORY_NAME,
            description=_OUTPUT_CATEGORY_DESCRIPTION)
    elif output_category_md.name != _OUTPUT_CATRGORY_NAME:
        logging.warning(warn_message_format, _OUTPUT_CATRGORY_NAME)

    if output_score_md is None:
        output_score_md = metadata_info.ClassificationTensorMd(
            name=_OUTPUT_SCORE_NAME,
            description=_OUTPUT_SCORE_DESCRIPTION,
        )
    elif output_score_md.name != _OUTPUT_SCORE_NAME:
        logging.warning(warn_message_format, _OUTPUT_SCORE_NAME)

    if output_number_md is None:
        output_number_md = metadata_info.TensorMd(
            name=_OUTPUT_NUMBER_NAME, description=_OUTPUT_NUMBER_DESCRIPTION)
    elif output_number_md.name != _OUTPUT_NUMBER_NAME:
        logging.warning(warn_message_format, _OUTPUT_NUMBER_NAME)

    # Create output tensor group info. The number-of-detections tensor is
    # intentionally not part of the group.
    group = _metadata_fb.TensorGroupT()
    group.name = _GROUP_NAME
    group.tensorNames = [
        output_location_md.name, output_category_md.name, output_score_md.name
    ]

    # Gets the tensor indices of tflite outputs and then gets the order of the
    # output metadata by the value of tensor indices. For instance, if the
    # output indices are [601, 599, 598, 600], tensor names and indices aligned
    # are:
    # - location: 598
    # - category: 599
    # - score: 600
    # - number of detections: 601
    # because of the op's ports of TFLITE_DETECTION_POST_PROCESS
    # (https://github.com/tensorflow/tensorflow/blob/a4fe268ea084e7d323133ed7b986e0ae259a2bc7/tensorflow/lite/kernels/detection_postprocess.cc#L47-L50).
    # Thus, the metadata of tensors are sorted in this way, according to
    # output_tensor_indices correctly.
    # list() guarantees a sized, re-iterable sequence: it is measured, sorted
    # and iterated below.
    output_tensor_indices = list(_get_tflite_outputs(model_buffer))
    metadata_list = [
        _create_location_metadata(output_location_md),
        _create_metadata_with_value_range(output_category_md),
        _create_metadata_with_value_range(output_score_md),
        output_number_md.create_metadata()
    ]

    # The sort-based alignment requires exactly one model output per metadata
    # entry. Otherwise zip() would silently truncate (mislabeling the outputs)
    # or the lookup below would fail with an unhelpful KeyError.
    if len(output_tensor_indices) != len(metadata_list):
        raise ValueError(
            "The object detector model is expected to have %d output tensors "
            "(location, category, score and number of detections), but it has "
            "%d." % (len(metadata_list), len(output_tensor_indices)))

    # Align indices with tensors.
    sorted_indices = sorted(output_tensor_indices)
    indices_to_tensors = dict(zip(sorted_indices, metadata_list))

    # Output metadata according to output_tensor_indices.
    output_metadata = [indices_to_tensors[i] for i in output_tensor_indices]

    # Create subgraph info.
    subgraph_metadata = _metadata_fb.SubGraphMetadataT()
    subgraph_metadata.inputTensorMetadata = [input_md.create_metadata()]
    subgraph_metadata.outputTensorMetadata = output_metadata
    subgraph_metadata.outputTensorGroups = [group]

    # Create model metadata.
    model_metadata = general_md.create_metadata()
    model_metadata.subgraphMetadata = [subgraph_metadata]

    b = flatbuffers.Builder(0)
    b.Finish(
        model_metadata.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    # Collect the label/calibration files referenced by the category and score
    # metadata so they can be packed together with the model.
    associated_files = []
    _extend_new_files(associated_files, output_category_md.associated_files)
    _extend_new_files(associated_files, output_score_md.associated_files)
    return cls(model_buffer, b.Output(), associated_files=associated_files)