in onnxconverter_common/topology.py [0:0]
def convert_topology(topology, model_name, doc_string, target_opset, targeted_onnx=None, channel_first_inputs=None):
    '''
    Convert a Topology object (defined in _parser.py) into an ONNX model (type: ModelProto).

    :param topology: The Topology object we are going to convert.
    :param model_name: GraphProto's name. Let "model" denote the returned model. The string "model_name" would be
                       assigned to "model.graph.name."
    :param doc_string: A string attached to the produced model.
    :param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3.
    :param targeted_onnx: [deprecated] A string which specifies the targeted ONNX version of the produced model.
                          Possible values include '1.1.2', '1.2', and so on. Use target_opset instead.
    :param channel_first_inputs: Optional list of input names (with or without a trailing ':0') whose
                                 declared NHWC shape should be rewritten to NCHW.
    :return: an ONNX ModelProto
    '''
    if targeted_onnx is not None and StrictVersion(targeted_onnx) != StrictVersion(onnx.__version__):
        warnings.warn(
            'targeted_onnx is deprecated, please specify target_opset for the target model.\n' +
            '*** ONNX version conflict found. The installed version is %s while the targeted version is %s' % (
                onnx.__version__, targeted_onnx))

    # Resolve/validate the requested opset against what the installed onnx package supports.
    opset_from_onnx_version = get_maximum_opset_supported()
    if target_opset is None:
        target_opset = opset_from_onnx_version
    elif target_opset > opset_from_onnx_version:
        raise RuntimeError(("target_opset %d is higher than the number of the installed onnx package"
                            + " or the converter support (%d).") % (target_opset, opset_from_onnx_version))

    topology._initialize_graph_status_for_traversing()
    container = ModelComponentContainer(target_opset)

    # Put roots and leaves as ONNX's model into buffers. They will be added into ModelComponentContainer
    # later, following the order declared by the raw model.
    tensor_inputs = {}
    other_inputs = {}
    tensor_outputs = {}
    other_outputs = {}
    for scope in topology.scopes:
        for variable in scope.variables.values():
            if variable.is_root:
                if isinstance(variable.type, (TensorType, Int64Type, FloatType, StringType)):
                    tensor_inputs[variable.raw_name] = variable
                else:
                    other_inputs[variable.raw_name] = variable
            if variable.is_leaf:
                if isinstance(variable.type, (TensorType, Int64Type, FloatType, StringType)):
                    tensor_outputs[variable.raw_name] = variable
                else:
                    other_outputs[variable.raw_name] = variable

    # Add roots to the graph according to their order in the original model.
    invalid_name = []
    nhwc_inputs = []
    if channel_first_inputs is None:
        channel_first_inputs = []
    for name in topology.raw_model.input_names:
        # Check input naming convention: after dropping '_', ':' and '/', the name must be
        # alphanumeric and must not start with a digit.
        input_name = name.replace('_', '').replace(":", "").replace("/", "")
        if input_name and (input_name[0].isdigit() or (not input_name.isalnum())):
            invalid_name.append(name)
        if name in tensor_inputs:
            onnx_input = tensor_inputs[name]  # type: Variable
            if name in channel_first_inputs or \
                    (name.endswith(':0') and name[:-2] in channel_first_inputs):
                nhwc_inputs.append(onnx_input.full_name)
                # Rewrite the declared shape from NHWC to NCHW.
                # NOTE(review): this assumes a rank-4 shape — confirm that only 4-D inputs
                # are ever listed in channel_first_inputs.
                s = onnx_input.type.shape
                onnx_input.type.shape = [s[0], s[3], s[1], s[2]]
            container.add_input(onnx_input)

    if invalid_name:
        warnings.warn('Some input names are not compliant with ONNX naming convention: %s' % invalid_name)

    for name in topology.raw_model.input_names:
        if name in other_inputs:
            container.add_input(other_inputs[name])

    # Add leaves to the graph according to their order in the original model.
    invalid_name = []
    for name in topology.raw_model.output_names:
        # Check output naming convention (same rule as for inputs).
        output_name = name.replace('_', '').replace(":", "").replace("/", "")
        if output_name and (output_name[0].isdigit() or (not output_name.isalnum())):
            invalid_name.append(name)
        if name in tensor_outputs:
            container.add_output(tensor_outputs[name])
    if invalid_name:
        warnings.warn('Some output names are not compliant with ONNX naming convention: %s' % invalid_name)
    for name in topology.raw_model.output_names:
        if name in other_outputs:
            container.add_output(other_outputs[name])

    # Traverse the graph from roots to leaves, converting one operator at a time.
    for operator in topology.topological_operator_iterator():
        scope = next(scope for scope in topology.scopes if scope.name == operator.scope)
        if operator.type in topology.custom_conversion_functions:
            topology.custom_conversion_functions[operator.type](scope, operator, container)
        else:
            # Convert the selected operator into some ONNX objects and save them into the container.
            get_converter(operator.type)(scope, operator, container)

    # When calling ModelComponentContainer's add_initializer(...), nothing is added into the input list.
    # However, for ONNX target opset < 9, initializers should also be model's (GraphProto) inputs.
    # Thus, we create ValueInfoProto objects from initializers (type: TensorProto) directly and
    # then add them into model's input list.
    extra_inputs = []  # ValueInfoProto list of the initializers
    # Build the name set once; the original rebuilt a list per initializer (O(n^2)).
    existing_input_names = {value_info.name for value_info in container.inputs}
    for tensor in container.initializers:
        # Sometimes (especially when creating optional input values such as RNN's initial hidden state),
        # an initializer is also one of the original model's inputs, so it has already been added into
        # the container's input list. In that case, skip it to avoid duplicated inputs.
        if tensor.name in existing_input_names:
            continue
        # Initializers are always tensors so we can just call make_tensor_value_info(...)
        value_info = helper.make_tensor_value_info(tensor.name, tensor.data_type, tensor.dims)
        extra_inputs.append(value_info)

    # Enable the ONNX optimizations.
    if container.enable_optimizer:
        nodes = optimize_onnx(container.nodes, nhwc_inputs, container.inputs + extra_inputs, container.outputs)
    else:
        nodes = container.nodes

    # Create a graph from its main components.
    if container.target_opset < 9:
        # Before ONNX opset 9, initializers need to be passed in with inputs.
        graph = helper.make_graph(nodes, model_name, container.inputs + extra_inputs,
                                  container.outputs, container.initializers)
    else:
        # In ONNX opset 9 and above, initializers are included as operator
        # inputs and therefore do not need to be passed as extra_inputs.
        graph = helper.make_graph(nodes, model_name, container.inputs,
                                  container.outputs, container.initializers)

    # Add extra information related to the graph.
    graph.value_info.extend(container.value_info)

    onnx_model = make_model_ex(graph, container.node_domain_version_pair_sets,
                               target_opset, topology.metadata_props, doc_string=doc_string)
    return onnx_model