in mmdnn/conversion/coreml/coreml_emitter.py [0:0]
def gen_model(self,
              input_names=None,
              output_names=None,
              image_input_names=None,
              is_bgr=False,
              red_bias=0.0,
              green_bias=0.0,
              blue_bias=0.0,
              gray_bias=0.0,
              image_scale=1.0,
              class_labels=None,
              predicted_feature_name=None,
              predicted_probabilities_output=''):
    """Convert the IR graph into a Core ML neural-network spec.

    Every node of ``self.IR_graph`` is visited in topological order and
    dispatched to the matching ``emit_<node type>`` method of this emitter.

    Args:
        input_names: Not used by this method; kept for interface
            compatibility.
        output_names: Not used by this method; kept for interface
            compatibility.
        image_input_names: Name, or iterable of names, of model inputs to
            mark as images for pre-processing. When empty or ``None``, the
            first input returned by ``self._get_inout()`` is used.
        is_bgr, red_bias, green_bias, blue_bias, gray_bias, image_scale:
            Forwarded unchanged to
            ``self.builder.set_pre_processing_parameters``.
        class_labels: When not ``None``, the builder is created in
            ``'classifier'`` mode. Either a list/tuple of labels or a path to
            a text file containing one label per line.
        predicted_feature_name: Forwarded to ``self.builder.set_class_labels``
            when not ``None``.
        predicted_probabilities_output: Forwarded to
            ``self.builder.set_class_labels`` as ``prediction_blob`` when
            non-empty.

    Returns:
        tuple: ``(self.builder.spec, input_features, output_features)``.

    Raises:
        AssertionError: If a node type has no ``emit_<type>`` handler.
        ValueError: If ``class_labels`` is a path that does not exist, or is
            neither a string nor a list/tuple.
    """
    input_features, output_features = self._get_inout()
    is_classifier = class_labels is not None
    mode = 'classifier' if is_classifier else None
    self.builder = _NeuralNetworkBuilder(input_features, output_features, mode=mode)

    for layer in self.IR_graph.topological_sort:
        current_node = self.IR_graph.get_node(layer)
        print("Converting layer {}({})".format(current_node.name, current_node.type))
        node_type = current_node.type
        emit = getattr(self, "emit_" + node_type, None)
        if emit is not None:
            emit(current_node)
        else:
            print("CoreMLEmitter has not supported operator [%s]." % (node_type))
            self.emit_UNKNOWN(current_node)
            # Explicit raise instead of ``assert False``: an assert is stripped
            # under ``python -O`` and conversion would silently continue.
            raise AssertionError(
                "CoreMLEmitter has not supported operator [%s]." % (node_type))

    # Add classifier classes (if applicable)
    if is_classifier:
        classes_in = class_labels
        if isinstance(classes_in, _string_types):
            if not os.path.isfile(classes_in):
                raise ValueError("Path to class labels [{}] does not exist.".format(classes_in))
            with open(classes_in, 'r') as f:
                classes = f.read().splitlines()
        elif isinstance(classes_in, (list, tuple)):  # list[int or str]
            classes = list(classes_in)
        else:
            raise ValueError('Class labels must be a list of integers / strings, or a file path')

        # Only pass what the caller actually supplied so the builder's own
        # defaults apply otherwise. The probabilities blob used to be
        # dropped whenever predicted_feature_name was None.
        label_kwargs = {}
        if predicted_feature_name is not None:
            label_kwargs['predicted_feature_name'] = predicted_feature_name
        if predicted_probabilities_output:
            label_kwargs['prediction_blob'] = predicted_probabilities_output
        self.builder.set_class_labels(classes, **label_kwargs)

    # Resolve which inputs get image pre-processing. Historically this was
    # always the first model input and the argument was ignored.
    if not image_input_names:
        image_inputs = [input_features[0][0]]
    elif isinstance(image_input_names, _string_types):
        image_inputs = [image_input_names]
    else:
        image_inputs = list(image_input_names)

    # Set pre-processing parameters
    self.builder.set_pre_processing_parameters(
        image_input_names=image_inputs,
        is_bgr=is_bgr,
        red_bias=red_bias,
        green_bias=green_bias,
        blue_bias=blue_bias,
        gray_bias=gray_bias,
        image_scale=image_scale)

    # Return the protobuf spec
    print(self.builder.spec.description)
    return self.builder.spec, input_features, output_features