python/tvm/relay/frontend/tensorflow.py
def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Follow the tensorflow graph definition to parse and convert it to Relay.
Some of the assumptions are listed below.
-> All Placeholders are considered as graph input.
-> All Const nodes are params.
-> The last node is assumed to be the graph output.
-> _output_shapes : The graph should be frozen with add_shapes=True.
Alternatively, the user can pass an input shape dictionary.
-> DecodeJpeg, ResizeBilinear: These are dummy operators.
Hence the user should handle preprocessing outside.
-> CheckNumerics: Not implemented yet; it just copies the input
to the output.
Parameters
----------
graph : tensorflow graph definition object
The loaded tensorflow GraphDef
layout : target layout to be used (Optional)
Currently only "NCHW" is supported, to enable NHWC models on GPU.
shape : Dictionary of input dimensions (Optional)
Graph level input shape dictionary.
outputs : List of output tensor names (Optional)
If not specified, the last node is assumed to be the graph output.
Returns
-------
func : tvm.relay.Function
The converted Relay function. The corresponding parameters (used as
pretrained weights) are accumulated in self._params as
name : tvm.nd.array pairs.
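Examples
--------
A minimal sketch of the public entry point that ends up calling this
method; ``graph_def``, the input name "input", and its shape are
assumptions for illustration:
>>> from tvm import relay
>>> mod, params = relay.frontend.from_tensorflow(
...     graph_def, layout="NCHW", shape={"input": (1, 224, 224, 3)}
... )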
"""
try:
from tensorflow.python.framework import tensor_util
except ImportError as e:
raise ImportError("Unable to import tensorflow, which is required: {}".format(e))
missing_operators = self._parse_import_prerequisites(graph)
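# Buckets for nodes that need special handling later: control flow ops
# and TensorArray constructor/write/gather ops.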
control_flow_nodes = []
ta_write_nodes = []
ta_gather_nodes = []
ta_construct_nodes = []
self._in_shape = shape
self._layout = layout
self._graph = graph
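# Bail out early if the graph contains unsupported operators. Operators
# that are normally pruned away by freezing indicate an unfrozen graph.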
if missing_operators:
freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list]
if freezed_ops:
raise Exception(
"Graph is not frozen. Provide a frozen graph. "
"Found operators {}".format(freezed_ops)
)
raise NotImplementedError(
"The following operators are not implemented: {}".format(missing_operators)
)
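# First pass over the GraphDef: record every node, parse shape
# attributes, convert Placeholder/Const nodes, and bucket control flow
# and TensorArray nodes for the later passes.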
for node in graph.node:
node_name_prefix = node.name.rsplit("/", 1)[0]
self._control_flow_node_map[node_name_prefix].add(node.op)
self._tf_node_map[node.name] = node
# Parse output_shapes attribute
parsed_attr = self._parse_attr(node.attr)
if "_output_shapes" in parsed_attr:
self._output_shapes[node.name] = [
tensor_util.TensorShapeProtoToList(tshape)
for tshape in parsed_attr["_output_shapes"]
]
else:
self._output_shapes[node.name] = [None]
# Parse placeholder and const here since input shape info is required.
if node.op == "Placeholder" or node.op == "PlaceholderWithDefault":
# Give priority to user argument.
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
else:
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
node.attr["shape"].shape
)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
self._input_shapes[node.name][idx] = Any()
self._output_shapes[node.name] = [self._input_shapes[node.name]]
attr = self._parse_attr(node.attr)
self._nodes[node.name] = [
_expr.var(
node.name, shape=self._input_shapes[node.name], dtype=attr["dtype"].name
)
]
# Ignore user's input shape for non-placeholder nodes.
elif node.op == "Const":
tensor_value = node.attr["value"].tensor
self._input_shapes[node.name] = tensor_util.TensorShapeProtoToList(
tensor_value.tensor_shape
)
self._output_shapes[node.name] = [self._input_shapes[node.name]]
if shape and node.name in shape:
warnings.warn(
"Ignore the passed shape. Shape in graphdef "
"will be used for operator %s." % node.name
)
for key, value in node.attr.items():
self._parse_param(key, value, node.name, self._in_shape)
elif node.op in _control_flow_nodes:
# We assume that the direct parent node of Exit is a while loop block
if node.op == "Exit":
self._while_loop_name_set.add(node_name_prefix)
control_flow_nodes.append(node)
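# TensorArray ops are only collected here; their static shapes are
# inferred from the matching gather/write nodes below.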
elif node.op.startswith("TensorArray"):
if is_tensor_array_constuctor(node):
ta_construct_nodes.append(node)
else:
for ta_write_name, idx in _tensor_array_write_ops.items():
if node.op.startswith(ta_write_name):
ta_write_nodes.append((node, idx))
break
if node.op.startswith("TensorArrayGather"):
ta_gather_nodes.append(node)
# Use tensor array gather to infer static tensor array shape
for gather_node in ta_gather_nodes:
input_ta_name = gather_node.input[0]
input_ta_node = self._tf_node_map[input_ta_name]
if is_tensor_array_constuctor(input_ta_node):
gather_attr = self._parse_attr(gather_node.attr)
if "element_shape" not in gather_attr:
continue
raw_elem_shape = tensor_util.TensorShapeProtoToList(gather_attr["element_shape"])
elem_shape = []
for dim in raw_elem_shape:
if dim < 0:
elem_shape.append(Any())
else:
elem_shape.append(int(dim))
self._tensor_array_shapes[input_ta_node.name] = elem_shape
# For each TensorArray write, locate the constructor node and record
# the node that provides its static shape.
for item in ta_write_nodes:
wnode = item[0]
ta_idx, inode_idx = item[1]
stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
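# Walk backwards (breadth-first) through the write node's inputs until
# the TensorArray constructor that it writes into is found.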
while stack:
cnode = stack.pop(0)
if not cnode.op.startswith("TensorArray"):
for iname in cnode.input:
stack.append(self._tf_node_map[iname.split(":")[0]])
elif cnode.name != wnode.name:
if is_tensor_array_constuctor(cnode):
inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]]
tn = wnode.input[inode_idx].split(":")
output_index = int(tn[1]) if len(tn) > 1 else 0
self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op, output_index)
break
# First, parse all control flow nodes.
# Convert tf.cond to Branch and tf.while_loop to Loop.
sorted_cf_nodes = []
exit_pos_map = {}
ordered_prefix = []
# Sort control flow nodes so that all Exit nodes are moved to the end
# of the corresponding while_loop block.
for node in control_flow_nodes:
loop_name = find_parent_loop_name(node.name, self._while_loop_name_set)
if node.op == "Exit":
if loop_name not in exit_pos_map:
ordered_prefix.append(loop_name)
exit_pos_map[loop_name] = len(sorted_cf_nodes)
sorted_cf_nodes.append(node)
elif loop_name in self._while_loop_name_set:
if loop_name not in exit_pos_map:
sorted_cf_nodes.append(node)
else:
sorted_cf_nodes.insert(exit_pos_map[loop_name], node)
for j in range(ordered_prefix.index(loop_name), len(ordered_prefix)):
exit_pos_map[ordered_prefix[j]] += 1
else:
sorted_cf_nodes.append(node)
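# Record the sorted order, then convert the control flow nodes to
# Relay first.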
for node in sorted_cf_nodes:
self._sorted_cf_node_names.append(node.name)
for node in sorted_cf_nodes:
self._backtrack_construct(node.name)
# Second, parse other nodes to re-create TF graph using Relay operators.
for node in graph.node:
self._backtrack_construct(node.name)
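# Collect graph outputs: user-specified tensor names take priority;
# otherwise the last node in the GraphDef is treated as the output.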
out = []
if outputs is None:
last_node = graph.node[-1]
op = self._nodes[last_node.name.split(":")[0]]
if last_node.op == "Exit":
out = [op[0].tuple_value]
else:
out = op
else:
for out_name in outputs:
if ":" in out_name:
out_name, out_num = out_name.split(":")
out_num = int(out_num)
out.append(self._nodes[out_name][out_num])
else:
out.append(self._nodes[out_name][0])
if isinstance(out, _expr.TupleWrapper):
out = out.tuple_value
else:
out = out[0] if len(out) == 1 else _expr.Tuple(out)
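# Build the Relay function and keep only the params that are actually
# referenced as free variables in the output expression.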
fvars = analysis.free_vars(out)
func = _function.Function(fvars, out)
final_params = {}
for fv in fvars:
if fv.name_hint in self._params:
final_params[fv.name_hint] = self._params[fv.name_hint]
self._params = final_params
return func