def extract_subgraph()

in coremltools/converters/mil/frontend/tensorflow/tfssa.py [0:0]


    def extract_subgraph(self, outputs, target_inputs=None, name=""):
        """Add a new SSAFunction to the current NetworkEnsemble to produce the given outputs.

        Steps:

        1. Find the functions. Each function in ``self.functions`` whose graph
           contains at least one name from ``outputs`` is processed.
        2. Copy and walk upstream. The function's graph is deep-copied. Every
           ancestor of the requested outputs is collected by walking ``inputs``
           and ``control_inputs``. The walk does not continue past nodes listed
           in ``target_inputs``.
        3. Keep ``set_global`` nodes. From the source nodes reached by that walk
           (target inputs, and non-``Const`` nodes with no inputs), the code walks
           downstream through ``outputs`` and ``control_outputs``. Every
           ``set_global`` node found this way is kept, along with its ancestors.
        4. Prune. All other nodes, and edges pointing at them, are dropped.
           Nodes named in ``target_inputs`` become ``Placeholder`` ops.
        5. Wrap outputs. Each output node ``X`` is renamed to ``preIdentity_X``.
           A new ``Identity`` node named ``X`` is attached to it, so the new
           function's outputs keep their original names.

        Args:
            outputs: List of node names the new function must produce. The
                caller's list is not modified.
            target_inputs: Optional list of node names at which the upstream
                walk stops. These nodes become ``Placeholder`` ops in the new
                function. Defaults to no extra stopping points.
            name: The name of the new function to create. If empty, a name is
                generated by joining the sorted output names with ``_``.

        Returns:
            The name of the new function.

        Raises:
            TypeError: If ``outputs`` is not a list.
        """
        if not isinstance(outputs, list):
            raise TypeError("Expected a list of output names for subgraph extraction")

        if name == "":
            # Sort a copy: sorting in place would silently reorder the caller's list.
            outputs = sorted(outputs)
            name = escape_fn_name("_".join(outputs))

        if target_inputs is None:
            target_inputs = []
        # Built once so membership checks during the graph walks are O(1).
        target_set = set(target_inputs)

        # Both graph walks use explicit stacks rather than recursion, so long
        # chains of nodes cannot exceed Python's recursion limit.
        def collect_upstream(graph, start, visited):
            """Add ``start`` and its ancestors to ``visited``.

            Returns the source nodes reached: target inputs, or non-Const
            nodes with no inputs.
            """
            sources = []
            stack = [start]
            while stack:
                node = stack.pop()
                visited.add(node)
                if node in target_set:
                    # A requested input: do not walk past it.
                    sources.append(node)
                    continue
                entry = graph[node]
                preds = entry.inputs + entry.control_inputs
                if not preds:
                    if entry.op != "Const":
                        sources.append(node)
                    continue
                for pred in preds:
                    if pred not in visited:
                        # Mark on push so each node is queued at most once.
                        visited.add(pred)
                        stack.append(pred)
            return sources

        def collect_set_globals(graph, start, visited):
            """Return ``set_global`` nodes reachable downstream of ``start``.

            Follows ``outputs``/``control_outputs`` and skips nodes already in
            ``visited``.
            """
            found = []
            stack = [start]
            while stack:
                node = stack.pop()
                visited.add(node)
                entry = graph[node]
                if entry.op == "set_global":
                    found.append(node)
                for succ in entry.outputs + entry.control_outputs:
                    if succ not in visited:
                        visited.add(succ)
                        stack.append(succ)
            return found

        def rename_refs(refs, old, new):
            """Replace every occurrence of ``old`` in the list ``refs``, in place."""
            for idx, ref in enumerate(refs):
                if ref == old:
                    refs[idx] = new

        for fn_name in list(self.functions):
            fn = self.functions[fn_name]
            extract = [output for output in outputs if output in fn.graph]
            if not extract:
                continue

            gdict = copy.deepcopy(fn.graph)

            # 1. Everything the requested outputs depend on.
            incl_nodes = set()
            sources = []
            for output in extract:
                sources += collect_upstream(gdict, output, incl_nodes)

            # 2. set_global nodes downstream of those sources, plus their
            #    ancestors, are also kept. Presumably this preserves their side
            #    effects in the extracted function; verify against callers.
            fwd_visited = set()
            set_globals = []
            for src in sources:
                set_globals += collect_set_globals(gdict, src, fwd_visited)
            for node in set_globals:
                collect_upstream(gdict, node, incl_nodes)

            # 3. Drop excluded nodes and every edge that points at one.
            #    Requested inputs become Placeholders.
            for node_name, orig_node in fn.graph.items():
                if node_name not in incl_nodes:
                    del gdict[node_name]
                    continue
                new_node = gdict[node_name]
                if node_name in target_set:
                    new_node.op = "Placeholder"
                new_node.inputs = [n for n in orig_node.inputs if n in incl_nodes]
                new_node.outputs = [n for n in orig_node.outputs if n in incl_nodes]
                new_node.control_inputs = [
                    n for n in orig_node.control_inputs if n in incl_nodes
                ]
                new_node.control_outputs = [
                    n for n in orig_node.control_outputs if n in incl_nodes
                ]

            # 4. Rename each output X to preIdentity_X and attach a fresh
            #    Identity node named X, so the outputs keep their names.
            for output in extract:
                old_name = "preIdentity_" + output
                producer = gdict[output]
                identity = copy.deepcopy(producer)
                identity.op = "Identity"
                identity.inputs = [old_name]
                identity.control_inputs = []
                identity.outputs = []
                identity.control_outputs = []

                # Re-point every neighbour's edge from X to preIdentity_X.
                for inp in producer.inputs:
                    rename_refs(gdict[inp].outputs, output, old_name)
                for inp in producer.control_inputs:
                    rename_refs(gdict[inp].control_outputs, output, old_name)
                for out in producer.outputs:
                    rename_refs(gdict[out].inputs, output, old_name)
                for out in producer.control_outputs:
                    rename_refs(gdict[out].control_inputs, output, old_name)

                producer.outputs.append(output)
                producer.name = old_name
                gdict[old_name] = producer
                gdict[output] = identity

            # NOTE(review): if ``outputs`` are spread over several functions,
            # each iteration overwrites self.functions[name] and only the last
            # one survives. Confirm whether that is intended.
            self.functions[name] = SSAFunction(gdict)
        return name