in model_optimizer_pkg/model_optimizer_pkg/model_optimizer_node.py [0:0]
def convert_to_mo_cli(self,
                      model_name,
                      model_metadata_sensors,
                      training_algorithm,
                      input_width,
                      input_height,
                      lidar_channels,
                      aux_inputs):
    """Helper method that converts the information in the model optimizer API into
       the appropriate cli commands.

    The flags listed in constants.APIFlags are first set to their defaults from
    constants.APIDefaults, then overridden by matching entries in aux_inputs, and
    finally translated into model optimizer parameters. When sensors are given, the
    input shape and input name parameters are replaced by one entry per sensor.

    Args:
        model_name (str): Model prefix, should be the same in the weight and symbol file.
        model_metadata_sensors (list): List of sensor input types(int) for all the sensors
                                       with which the model was trained.
        training_algorithm (int): Training algorithm key(int) for the algorithm with which
                                  the model was trained.
        input_width (int): Width of the input image to the inference engine.
        input_height (int): Height of the input image to the inference engine.
        lidar_channels (int): Number of LiDAR values with which the LiDAR head of
                              the model was trained.
        aux_inputs (dict): Dictionary of auxiliary options for the model optimizer.
                           Keys that are not known API flags are ignored.

    Raises:
        Exception: Custom exception if the API flags and default values are not
                   aligned.
        Exception: Custom exception if a LiDAR sensor is present and the lidar_channels
                   value is less than 1.

    Returns:
        dict: Map of parameters to be passed to model optimizer command based on the model.
    """
    api_flags = constants.APIFlags.get_list()
    api_defaults = constants.APIDefaults.get_list()
    if len(api_flags) != len(api_defaults):
        raise Exception("Inconsistent API flags")

    # Set the flags to the default values.
    default_param = dict(zip(api_flags, api_defaults))
    # Override the defaults with the values the user entered in aux_inputs.
    for flag, value in aux_inputs.items():
        if flag in default_param:
            default_param[flag] = value

    # Dictionary that will house the cli commands.
    common_params = {}
    # Convert API information into appropriate cli commands.
    # NOTE: flags and values are compared with ==/!= rather than `is`. Values may
    # come from aux_inputs, so a value equal to the default can be a distinct
    # object, and an identity check would wrongly treat it as non-default.
    for flag, value in default_param.items():
        if flag == constants.APIFlags.MODELS_DIR:
            common_params[constants.MOKeys.MODEL_PATH] = os.path.join(value, model_name)
        elif flag == constants.APIFlags.IMG_CHANNEL:
            # Input shape is in the form [n,h,w,c] to support tensorflow models only.
            common_params[constants.MOKeys.INPUT_SHAPE] = \
                constants.MOKeys.INPUT_SHAPE_FMT.format(1, input_height, input_width, value)
        elif flag == constants.APIFlags.PRECISION:
            common_params[constants.MOKeys.DATA_TYPE] = value
        elif flag == constants.APIFlags.FUSE:
            # The disable-fuse switches are only emitted for a non-default value.
            if value != constants.APIDefaults.FUSE:
                common_params[constants.MOKeys.DISABLE_FUSE] = ""
                common_params[constants.MOKeys.DISABLE_GFUSE] = ""
        elif flag == constants.APIFlags.IMG_FORMAT:
            # The reverse-channels switch is only emitted for the default image format.
            if value == constants.APIDefaults.IMG_FORMAT:
                common_params[constants.MOKeys.REV_CHANNELS] = ""
        elif flag == constants.APIFlags.OUT_DIR:
            common_params[constants.MOKeys.OUT_DIR] = value
        elif value:
            # Any other flag is passed through unchanged, but only when its value is truthy.
            common_params[flag] = value

    # Override the input shape and the input flags to handle multi head inputs in tensorflow.
    input_shapes = []
    input_names = []
    training_algorithm_key = constants.TrainingAlgorithms(training_algorithm)
    lidar_input_keys = (constants.SensorInputTypes.LIDAR,
                        constants.SensorInputTypes.SECTOR_LIDAR)
    for input_type in model_metadata_sensors:
        input_key = constants.SensorInputTypes(input_type)
        shape_format = constants.INPUT_SHAPE_FORMAT_MAPPING[input_key]
        if input_key in lidar_input_keys:
            if lidar_channels < 1:
                raise Exception("Lidar channels less than 1")
            input_shapes.append(shape_format.format(1, lidar_channels))
        else:
            # Input shape is in the form [n,h,w,c] to support tensorflow models only.
            input_shapes.append(
                shape_format.format(1,
                                    input_height,
                                    input_width,
                                    constants.INPUT_CHANNEL_SIZE_MAPPING[input_key]))
        input_name_format = constants.NETWORK_INPUT_FORMAT_MAPPING[input_key]
        input_names.append(
            input_name_format.format(
                constants.INPUT_HEAD_NAME_MAPPING[training_algorithm_key]))

    # Both lists grow by exactly one entry per sensor, so they always have the same
    # length; it is enough to check that at least one sensor was processed.
    if input_names:
        common_params[constants.MOKeys.INPUT_SHAPE] = \
            constants.MOKeys.INPUT_SHAPE_DELIM.join(input_shapes)
        common_params[constants.APIFlags.INPUT] = \
            constants.MOKeys.INPUT_SHAPE_DELIM.join(input_names)

    common_params[constants.MOKeys.MODEL_NAME] = model_name
    return common_params