in python/graph_def_util.py [0:0]
def restore_compiler_failures(compiled_graph_def, original_graph_def):
    """Restore `NeuronOp`'s that failed to compile.

    Every `NeuronOp` in `compiled_graph_def` whose `knExecutable` attribute is
    empty is replaced by the non-placeholder nodes of its embedded subgraph:

    * subgraph placeholders are rewired to the tensors that fed the `NeuronOp`
      (or, when the feeder is itself a failed `NeuronOp`, to that op's subgraph
      output tensor);
    * consumers of the failed `NeuronOp`'s outputs are rewired to the
      corresponding subgraph output tensors;
    * control inputs recorded for a restored node in `original_graph_def` are
      re-attached when their source node exists in the result;
    * any remaining input whose producing op is not in the result is dropped.

    Input rewiring is applied to the copies placed in the returned `GraphDef`,
    not to the nodes of `compiled_graph_def`.

    TODO: Some passes introduced recently can change subgraph input/output
    signatures. To deal with these cases properly, we need to obtain original
    NodeDef messages from `original_graph_def` instead of `subgraph_def`.

    Args:
        compiled_graph_def: GraphDef containing `NeuronOp` nodes, some of which
            may have an empty `knExecutable` (i.e. failed to compile).
        original_graph_def: GraphDef consulted only for the control inputs of
            the restored nodes.

    Returns:
        A new `GraphDef` with failed `NeuronOp`'s expanded back into their
        subgraph nodes, carrying a copy of `compiled_graph_def.library`.
    """
    neuron_op_dict = {node.name: node for node in get_neuron_nodes(compiled_graph_def)}
    restore_nodes = []
    remove_node_names = set()
    # Maps a failed NeuronOp output tensor name -> restored subgraph tensor name.
    gd_tensor_name_map = {}
    # Names of nodes that will exist in the result: every non-NeuronOp node of
    # the compiled graph, plus (added below) every restored subgraph node.
    all_expected_node_names = {node.name for node in compiled_graph_def.node if node.op != tNeuronOp}
    for node in get_neuron_nodes(compiled_graph_def):
        if node.attr[knExecutable].s:
            continue  # compiled successfully; the NeuronOp is kept as is
        remove_node_names.add(node.name)
        subgraph_def = get_subgraph_def(node)
        # Map each subgraph placeholder tensor to the graph tensor feeding it.
        sgd_tensor_name_map = {}
        for gd_ts_name, sg_ph_name in zip(node.input, node.attr[knInputNames].list.s):
            sgd_ph_name = format_tensor_name(sg_ph_name.decode())
            op_name, ts_index = _graph_def_op_index(gd_ts_name)
            in_node = neuron_op_dict.get(op_name)
            if in_node is not None and not in_node.attr[knExecutable].s:
                # The producer is another failed NeuronOp that is also being
                # expanded, so point directly at its subgraph output tensor.
                gd_ts_name = in_node.attr[knOutputNames].list.s[ts_index].decode()
            sgd_tensor_name_map[sgd_ph_name] = gd_ts_name
        for sg_node in subgraph_def.node:
            sg_node.input[:] = [sgd_tensor_name_map.get(name, name) for name in sg_node.input]
            # Placeholders are replaced by the rewiring above, so they are not restored.
            if sg_node.op != tPlaceholder:
                restore_nodes.append(sg_node)
                all_expected_node_names.add(sg_node.name)
        for out_idx, out_name in enumerate(node.attr[knOutputNames].list.s):
            out_gd_ts_name = format_tensor_name('{}:{}'.format(node.name, out_idx))
            gd_tensor_name_map[out_gd_ts_name] = format_tensor_name(out_name.decode())

    # A compiled-graph node that shares its name with a restored node is
    # dropped in favour of the restored one. (Names absent from the compiled
    # graph are harmless here: this set only filters compiled-graph nodes.)
    remove_node_names.update(node.name for node in restore_nodes)

    original_node_with_control_inputs = get_node_with_control_inputs(original_graph_def)
    for node in restore_nodes:
        if node.name in original_node_with_control_inputs:
            for name in original_node_with_control_inputs[node.name]:
                # NOTE(review): this compares the text before ':' against plain
                # node names; if get_node_with_control_inputs returns
                # '^'-prefixed names, the test never matches -- confirm.
                if name.split(':')[0] in all_expected_node_names:
                    node.input.append(name)

    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(
        node for node in compiled_graph_def.node if node.name not in remove_node_names)
    # Rewire consumers of failed NeuronOp outputs to the restored subgraph
    # tensors. `extend` stored copies, so this leaves `compiled_graph_def`
    # untouched; restored nodes are appended afterwards and are not rewired.
    for node in graph_def.node:
        node.input[:] = [gd_tensor_name_map.get(name, name) for name in node.input]
    graph_def.node.extend(restore_nodes)

    # Remove inputs whose producing op is not a node of the result graph.
    node_names = {node.name for node in graph_def.node}
    for node in graph_def.node:
        node.input[:] = [name for name in node.input if _graph_def_op_index(name)[0] in node_names]
    # preserve information for function-call operators (e. g., MapDataset)
    graph_def.library.CopyFrom(compiled_graph_def.library)
    return graph_def