in coremltools/converters/mil/frontend/tensorflow/tfssa.py [0:0]
def extract_subgraph(self, outputs, target_inputs=None, name=""):
    """Add a new SSAFunction to the current NetworkEnsemble to produce the given outputs.

    For every function in ``self.functions`` whose graph contains at least one
    of ``outputs``, a deep copy of that graph is pruned down to the nodes the
    requested outputs depend on, plus the dependencies of every ``set_global``
    node found downstream of the subgraph's boundary nodes. Each requested
    output node is then renamed to ``"preIdentity_" + <output name>`` and a new
    ``Identity`` node carrying the original name is attached after it. The
    pruned graph is stored as ``self.functions[name]``; if several functions
    contain the outputs, each match overwrites the previous one. If no function
    contains any of the outputs, nothing is added but ``name`` is still
    returned.

    Args:
        outputs: List of node names the new function must produce. The list
            itself is not modified.
        target_inputs: Optional collection of node names at which the upstream
            traversal stops. Every such node kept in the subgraph gets its op
            rewritten to ``"Placeholder"``. Defaults to no stop nodes.
        name: The name of the new function to create. If unspecified, a name will be generated
            by joining the sorted output names.
    Returns:
        The name of the new function.
    Raises:
        TypeError: If ``outputs`` is not a list.
    """
    if not isinstance(outputs, list):
        raise TypeError("Expected a list of output names for subgraph extraction")

    if name == "":
        # Sort a copy: the generated name (and processing order) stays the
        # same, but the caller's list is no longer reordered as a side effect.
        outputs = sorted(outputs)
        name = escape_fn_name("_".join(outputs))

    # Built once so the per-node membership tests below are O(1).
    stop_nodes = frozenset(target_inputs) if target_inputs is not None else frozenset()

    def collect_upstream(graph, start, visited):
        """Mark everything `start` depends on in `visited`; return boundary nodes.

        Boundary nodes are the stop nodes that were reached and the non-Const
        nodes that have neither inputs nor control inputs. An explicit stack
        is used so deep graphs cannot exhaust the interpreter's recursion limit.
        """
        boundary = []
        visited.add(start)
        stack = [start]
        while stack:
            node = stack.pop()
            if node in stop_nodes:
                # Do not look past an explicitly requested input.
                boundary.append(node)
                continue
            preds = graph[node].inputs + graph[node].control_inputs
            if not preds and graph[node].op != "Const":
                boundary.append(node)
                continue
            for pred in preds:
                if pred not in visited:
                    visited.add(pred)
                    stack.append(pred)
        return boundary

    def find_set_globals(graph, seeds):
        """Return the ``set_global`` nodes reachable downstream of `seeds`."""
        found = []
        visited = set()
        stack = []
        for seed in seeds:
            if seed not in visited:
                visited.add(seed)
                stack.append(seed)
        while stack:
            node = stack.pop()
            if graph[node].op == "set_global":
                found.append(node)
            for succ in graph[node].outputs + graph[node].control_outputs:
                if succ not in visited:
                    visited.add(succ)
                    stack.append(succ)
        return found

    def rename_ref(names, old, new):
        """Replace every occurrence of `old` in the list `names`, in place."""
        for idx, ref in enumerate(names):
            if ref == old:
                names[idx] = new

    # Snapshot the keys: self.functions gains an entry inside the loop.
    for fn_key in list(self.functions.keys()):
        source_graph = self.functions[fn_key].graph
        extract = [out for out in outputs if out in source_graph]
        if not extract:
            continue

        incl_nodes = set()
        gdict = copy.deepcopy(source_graph)

        # 1) Everything the requested outputs depend on.
        boundary = []
        for out in extract:
            boundary += collect_upstream(gdict, out, incl_nodes)
        # 2) set_global nodes fed by the boundary nodes, and their dependencies.
        for global_node in find_set_globals(gdict, boundary):
            collect_upstream(gdict, global_node, incl_nodes)

        # Drop excluded nodes and trim every edge list to the kept nodes.
        # Edge lists are rebuilt from the original graph's nodes.
        for node_name, orig_node in source_graph.items():
            if node_name not in incl_nodes:
                del gdict[node_name]
                continue
            kept = gdict[node_name]
            if node_name in stop_nodes:
                kept.op = "Placeholder"
            kept.inputs = [n for n in orig_node.inputs if n in incl_nodes]
            kept.outputs = [n for n in orig_node.outputs if n in incl_nodes]
            kept.control_inputs = [
                n for n in orig_node.control_inputs if n in incl_nodes
            ]
            kept.control_outputs = [
                n for n in orig_node.control_outputs if n in incl_nodes
            ]

        # Rename each output node to "preIdentity_<name>" and hang an Identity
        # node with the original name off it.
        for out in extract:
            old_name = "preIdentity_" + out
            producer = gdict[out]

            identity = copy.deepcopy(producer)
            identity.op = "Identity"
            identity.inputs = [old_name]
            identity.control_inputs = []
            identity.outputs = []
            identity.control_outputs = []

            # Point every neighbour's edge at the renamed producer.
            for pred in producer.inputs:
                rename_ref(gdict[pred].outputs, out, old_name)
            for pred in producer.control_inputs:
                rename_ref(gdict[pred].control_outputs, out, old_name)
            for succ in producer.outputs:
                rename_ref(gdict[succ].inputs, out, old_name)
            for succ in producer.control_outputs:
                rename_ref(gdict[succ].control_inputs, out, old_name)

            producer.outputs.append(out)
            producer.name = old_name
            gdict[old_name] = producer
            gdict[out] = identity

        self.functions[name] = SSAFunction(gdict)
    return name