in python/tvm/relay/frontend/onnx.py [0:0]
def from_onnx(self, graph, opset, get_output_expr=False):
    """Construct Relay expression from ONNX graph.

    The ONNX graph is a python protobuf object. Graph initializers are
    turned into Relay constants when parameters are frozen; otherwise they
    become Relay variables whose values are stored in the returned
    parameter dict. Graph inputs that are not initializers become Relay
    input variables, keeping the names they have in the ONNX graph.

    Parameters
    ----------
    graph : onnx protobuf object
        The loaded onnx graph

    opset : opset version

    get_output_expr: bool
        If set to true, this conversion will return each output expression rather
        than a packaged module. This can be useful when converting subgraphs to
        relay.

    Returns
    -------
    mod : tvm.IRModule
        The returned relay module

    params : dict
        A dict of name: tvm.nd.array pairs, used as pretrained weights

    When ``get_output_expr`` is true, only the output expression is
    returned instead of the ``(mod, params)`` pair: the single output
    expression if the graph has exactly one output, otherwise a Relay
    tuple of all outputs.

    Raises
    ------
    ValueError
        If an initializer has an empty (or whitespace-only) name.
    tvm.error.OpNotImplemented
        If the graph contains operators missing from the convert map.
    AssertionError
        If a user-specified shape names an input that is not in the graph
        (checked in the outer-most graph scope only), or if the number of
        converted outputs does not match the node's declared outputs.
    """
    self.opset = opset

    # Parse graph initializers (weights / biases) into Relay nodes.
    for init_tensor in graph.initializer:
        if not init_tensor.name.strip():
            raise ValueError("Tensor's name is required.")
        array = self._parse_array(init_tensor)
        if self._freeze_params:
            # Frozen parameters are embedded as constants, not tracked in _params.
            self._nodes[init_tensor.name] = _expr.const(array)
        else:
            self._params[init_tensor.name] = array
            self._nodes[init_tensor.name] = new_var(
                init_tensor.name,
                shape=self._params[init_tensor.name].shape,
                dtype=self._params[init_tensor.name].dtype,
            )

    # Parse graph inputs, separating real inputs from parameters.
    for graph_input in graph.input:
        # from onnx v0.2, GraphProto.input has type ValueInfoProto,
        # and the name is 'i.name'
        i_name, i_shape, d_type, i_shape_name = get_info(graph_input)
        if i_name in self._params:
            # This entry is a param instead of an input.
            self._num_param += 1
            # Pop and re-insert the entry, which moves it to the end of the
            # mapping (presumably so parameter order follows graph.input
            # order -- TODO confirm against callers).
            self._params[i_name] = self._params.pop(i_name)
            self._nodes[i_name] = new_var(
                i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
            )
        elif i_name in self._nodes:
            # Already materialised (e.g. as a frozen constant); not an input.
            continue
        else:
            self._num_input += 1
            self._input_names.append(i_name)
            if i_name in self._shape:
                # A user-specified shape overrides the one from the graph.
                i_shape = self._shape[i_name]
            elif "?" in str(i_shape):
                warning_msg = (
                    "Input %s has unknown dimension shapes: %s. "
                    "Specifying static values may improve performance"
                    % (i_name, str(i_shape_name))
                )
                warnings.warn(warning_msg)
            if isinstance(self._dtype, dict):
                dtype = self._dtype.get(i_name, d_type)
            else:
                dtype = d_type
            self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
        self._inputs[i_name] = self._nodes[i_name]

    # Only check user inputs in the outer-most graph scope.
    if self._old_manager is None:
        assert all(
            name in self._input_names for name in self._shape
        ), "User specified the shape for inputs that weren't found in the graph: " + str(
            self._shape
        )

    # Collect every unsupported op up front so they are reported together.
    convert_map = _get_convert_map(opset)
    unsupported_ops = {
        node.op_type
        for node in graph.node
        if node.op_type not in convert_map
        and node.op_type != "Constant"
        and node.op_type not in _identity_list
    }
    if unsupported_ops:
        msg = "The following operators are not supported for frontend ONNX: "
        # Sort so the message is deterministic across runs.
        msg += ", ".join(sorted(unsupported_ops))
        raise tvm.error.OpNotImplemented(msg)

    # construct nodes, nodes are stored as directed acyclic graph
    for node in graph.node:
        op_name = node.op_type
        attr = self._parse_attr(node.attribute)
        # Create and populate onnx input object.
        inputs = onnx_input()
        for input_name in node.input:
            if input_name != "":
                inputs[input_name] = self._nodes[self._renames.get(input_name, input_name)]
            else:
                # An empty name is stored as a None entry.
                inputs[input_name] = None
        i_name = self._parse_value_proto(node)
        node_output = self._fix_outputs(op_name, node.output)
        attr["tvm_custom"] = {}
        attr["tvm_custom"]["name"] = i_name
        attr["tvm_custom"]["num_outputs"] = len(node_output)

        op = self._convert_operator(op_name, inputs, attr, opset)
        outputs_num = len(op) if isinstance(op, _expr.TupleWrapper) else 1
        if outputs_num == 1:
            op = fold_constant(op)
        else:
            op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

        if outputs_num > 1:
            # ONNX supports optional outputs for some nodes.
            # This block searches for missing outputs in the ONNX graph
            # and removes any unneeded ops.
            # The scan is bounded by outputs_num so that a node declaring
            # more outputs than were produced reaches the mismatch assertion
            # below instead of failing here with an IndexError.
            valid_outputs = [
                idx < len(node_output) and node_output[idx] != ""
                for idx in range(outputs_num)
            ]
            # If we have outputs ONNX isn't expecting, we need to drop them
            if not all(valid_outputs):
                tup = op.astuple()
                # TupleWrapper can also wrap ops with TupleType outputs
                if isinstance(tup, _expr.Tuple):
                    # For tuples, we extract the fields instead of using GetTupleItem
                    outputs = [
                        tup.fields[idx] for idx, valid in enumerate(valid_outputs) if valid
                    ]
                else:
                    # For call nodes, we need to GetTupleItem
                    outputs = [op[idx] for idx, valid in enumerate(valid_outputs) if valid]
                # Create the new op with valid outputs
                if len(outputs) == 1:
                    op = outputs[0]
                elif len(outputs) != outputs_num:
                    op = _expr.TupleWrapper(_expr.Tuple(outputs), len(outputs))
                # Drop invalid outputs for the onnx node
                outputs_num = len(outputs)
                node_output = [output for output in node_output if output != ""]

        assert (
            len(node_output) == outputs_num
        ), "Number of output mismatch {} vs {} in {}.".format(
            len(node_output), outputs_num, op_name
        )

        if outputs_num == 1:
            self._nodes[node_output[0]] = op
        else:
            for idx, output_name in enumerate(node_output):
                self._nodes[output_name] = op[idx]

    # now return the outputs
    outputs = [self._nodes[self._parse_value_proto(out)] for out in graph.output]
    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
    # If requested, directly return the converted expressions.
    if get_output_expr:
        return outputs

    # Maintain the order of inputs and parameters from the ONNX graph, but only
    # include those parameters that are needed to execute the relay graph.
    free_vars = analysis.free_vars(outputs)
    # Reverse map from Relay expression back to its ONNX name.
    nodes = {v: k for k, v in self._nodes.items()}
    # A set makes the per-parameter membership test below O(1).
    free_var_names = {nodes[var] for var in free_vars}
    for i_name in self._params:
        if i_name in free_var_names and i_name not in self._inputs:
            self._inputs[i_name] = self._nodes[i_name]

    # Create a function from our output expression and all input variables.
    func = _function.Function(list(self._inputs.values()), outputs)
    return IRModule.from_expr(func), self._params