def functionalize_loops()

in coremltools/converters/mil/frontend/tensorflow/tf_graph_pass/functionalize_loops.py [0:0]


    def functionalize_loops(self, tfssa, function_to_functionalize) -> bool:
        """Rewrite one Enter-based loop of a function graph as a ``while`` node.

        The graph is taken from
        ``tfssa.functions[function_to_functionalize].graph``. The first node
        whose op is ``"Enter"`` is handed to ``self._search``; the loop is then
        rebuilt out of structured nodes:

        * the sources feeding the Enter nodes are rerouted into a single
          ``make_tuple`` node, which feeds a new ``while`` node;
        * each Merge node becomes a ``get_tuple`` node fed by a new
          ``function_entry`` node (the condition function), and the op of
          ``self.loopcond[0]`` is changed to ``"return"``;
        * each Switch node becomes a ``get_tuple`` node fed by a second
          ``function_entry`` node (the body function); the body results are
          gathered by a ``make_tuple`` node that feeds a new ``return`` node;
        * each Exit node becomes a ``get_tuple`` node attached to the
          ``while`` node.

        Finally the node dictionary is split into a condition graph, a body
        graph and the remaining main graph. The main graph replaces
        ``tfssa.functions[function_to_functionalize]`` and the other two are
        registered through ``tfssa.add_function``.

        This method reads ``self.enters``, ``self.is_constant``,
        ``self.loopcond``, ``self.cond``, ``self.body`` and
        ``self.next_iterations``. They are presumably populated by
        ``self._search`` (not visible here) -- TODO confirm.

        Args:
            tfssa: Object exposing a ``functions`` mapping, ``_find_free_name``
                and ``add_function``. It is mutated in place.
            function_to_functionalize: Key into ``tfssa.functions`` selecting
                the function whose graph is rewritten.

        Returns:
            ``False`` if the graph contains no ``"Enter"`` node (nothing is
            modified in that case), ``True`` after a loop has been rewritten.

        Raises:
            AssertionError: If the body's output tuple and the loop's input
                tuple end up with a different number of inputs.
        """
        g = tfssa.functions[function_to_functionalize].graph
        # Names of every Enter node in the graph; only the first one is used as
        # the starting point for this call.
        loopni = [a for a in g if g[a].op == "Enter"]
        if len(loopni) == 0:
            return False
        self._search(g, loopni[0])

        # Split the enters into constant and non-constant ones using the
        # parallel ``self.is_constant`` flags. ``self.constant_enters`` must be
        # computed first because ``self.enters`` is overwritten right after.
        self.constant_enters = [
            self.enters[i] for i in range(len(self.enters)) if self.is_constant[i]
        ]
        self.enters = [
            self.enters[i] for i in range(len(self.enters)) if not self.is_constant[i]
        ]
        self._fix_graph_invariants(g)
        # For each (non-constant) enter node, find the corresponding downstream
        # merge node. The lists built below are parallel to ``self.enters``.
        enter_corresponding_merge = [
            FindImmediateDownstreamNodes(lambda x: x.op == "Merge")
            .visit(g, enter)
            .get_result()[0]
            for enter in self.enters
        ]
        # For each merge node, find the NextIteration node feeding it.
        merge_corresponding_ni = [
            FindImmediateUpstreamNodes(lambda x: x.op == "NextIteration")
            .visit(g, merge)
            .get_result()[0]
            for merge in enter_corresponding_merge
        ]
        # For each merge node, find (or fabricate) the Switch node after it.
        switch_corresponding_merge = []
        for merge in enter_corresponding_merge:
            switch_after_merge = (
                FindImmediateDownstreamNodes(lambda x: x.op == "Switch")
                .visit(g, merge)
                .get_result()
            )
            if len(switch_after_merge) > 0:
                switch_corresponding_merge.append(switch_after_merge[0])
            else:
                # In some situations there is no Switch node for a given
                # Merge. While odd, that is OK: we construct one and wire it
                # to the merge and to the loop condition.
                # In this situation there is no Exit either, but it is
                # constructed in the next loop.
                new_switch_node = ParsedTFNode()
                new_switch_node.op = "Switch"
                new_switch_node.name = tfssa._find_free_name("fake_switch_")
                g[new_switch_node.name] = new_switch_node
                connect_edge(g, merge, new_switch_node.name)
                connect_edge(g, self.loopcond[0], new_switch_node.name)
                switch_corresponding_merge.append(new_switch_node.name)

        # For each switch node, find (or fabricate) the Exit node after it.
        exit_corresponding_switch = []
        for switch in switch_corresponding_merge:
            res = (
                FindImmediateDownstreamNodes(lambda x: x.op == "Exit")
                .visit(g, switch)
                .get_result()
            )
            if len(res) > 0:
                exit_corresponding_switch.append(res[0])
            else:
                new_exit_node = ParsedTFNode()
                new_exit_node.op = "Exit"
                new_exit_node.name = tfssa._find_free_name("fake_exit_")
                g[new_exit_node.name] = new_exit_node
                connect_edge(g, switch, new_exit_node.name)
                exit_corresponding_switch.append(new_exit_node.name)

        # The node that will stand for the whole loop in the main graph.
        while_loop = ParsedTFNode()
        while_loop.op = "while"
        while_loop.name = tfssa._find_free_name("while_")
        g[while_loop.name] = while_loop

        # Build the Loop Condition

        # replace all enters with a single make_tuple
        # we replace merge with get_tuple and turn it into a function call
        # terminated with LoopCond
        make_inputs = ParsedTFNode()
        make_inputs.op = "make_tuple"
        make_inputs.name = tfssa._find_free_name("make_input_")
        g[make_inputs.name] = make_inputs
        # NOTE(review): replace_dest presumably redirects the edge
        # (enter's first input -> enter) so that it ends at make_inputs instead
        # -- confirm against the helper.
        for enter in self.enters:
            replace_dest(g, g[enter].inputs[0], enter, make_inputs.name)
        # Number of inputs make_inputs has at this point, i.e. before the
        # constant enters are rerouted. It is used below as the offset of the
        # constant enters' tuple indices, so it must be sampled exactly here.
        constant_base_index = len(make_inputs.inputs)
        for enter in self.constant_enters:
            replace_dest(g, g[enter].inputs[0], enter, make_inputs.name)

        # make_inputs feeds the while node; the exits hang off the while node.
        connect_edge(g, make_inputs.name, while_loop.name)
        connect_dests(g, while_loop.name, exit_corresponding_switch)

        # build the cond function
        cond_body = ParsedTFNode()
        cond_body.op = "function_entry"
        cond_body.name = tfssa._find_free_name("cond_function_")
        cond_body.inputs = []
        g[cond_body.name] = cond_body
        for merge_idx in range(len(enter_corresponding_merge)):
            merge = enter_corresponding_merge[merge_idx]
            switch = switch_corresponding_merge[merge_idx]
            enter_node = g[self.enters[merge_idx]]
            merge_node = g[merge]
            # Every entry appended to switch_corresponding_merge above is a
            # node name, so the None branch looks purely defensive.
            if switch is not None:
                switch_node = g[switch]
            else:
                switch_node = None
            # The merge node becomes "element merge_idx of the cond function's
            # argument tuple".
            merge_node.op = "get_tuple"
            merge_node.attr = {"index": merge_idx}
            # disconnect enter from merge
            # disconnect merge from switch
            # disconnect loopcond from switch
            disconnect_edge(g, enter_node.name, merge_node.name)
            if switch_node is not None:
                disconnect_edge(g, merge_node.name, switch_node.name)
                disconnect_edge(g, self.loopcond[0], switch_node.name)
            # Drop every remaining input of the merge node; a copy of the list
            # is iterated, presumably because disconnect_edge mutates it.
            for i in merge_node.inputs[:]:
                disconnect_edge(g, i, merge_node.name)
            # Its only input is now the cond function entry.
            connect_edge(g, cond_body.name, merge_node.name)
            # delete get_tuple if it does nothing
            if len(merge_node.outputs) == 0:
                delete_node(g, merge)

        # The loop condition node terminates the cond function.
        g[self.loopcond[0]].op = "return"

        # build the body function
        body = ParsedTFNode()
        body.op = "function_entry"
        body.name = tfssa._find_free_name("body_function_")
        body.inputs = []
        g[body.name] = body
        for switch_idx in range(len(switch_corresponding_merge)):
            switch = switch_corresponding_merge[switch_idx]
            exit = exit_corresponding_switch[switch_idx]
            disconnect_edge(g, switch, exit)

            # replace switch with a get_tuple
            switch_node = g[switch]
            switch_node.op = "get_tuple"
            switch_node.attr = {"index": switch_idx}
            connect_edge(g, body.name, switch_node.name)
            # delete get_tuple if it does nothing
            if len(switch_node.outputs) == 0:
                delete_node(g, switch)

        # Replace all NextIteration nodes with a single make_tuple: the first
        # input of each NextIteration node is connected to make_outputs, which
        # later feeds the body's return node.
        make_outputs = ParsedTFNode()
        make_outputs.op = "make_tuple"
        make_outputs.name = tfssa._find_free_name("make_output_")
        g[make_outputs.name] = make_outputs
        for ni in merge_corresponding_ni:
            connect_edge(g, g[ni].inputs[0], make_outputs.name)

        # connect constant enters to come from function
        # connect constant enters to exit
        for idx, enter in enumerate(self.constant_enters):
            # Iterate over a copy: replace_source below presumably edits
            # g[enter].outputs.
            for output in list(g[enter].outputs):
                # A consumer that is in neither subgraph yet: look for overlap
                # with the cond / body node sets starting from it and, where
                # there is any, add the overlap plus the consumer itself to
                # that set.
                if output not in self.cond and output not in self.body:
                    cond_intersection = (
                        FindSubgraph(self.cond).visit(g, output).get_result()
                    )
                    body_intersection = (
                        FindSubgraph(self.body).visit(g, output).get_result()
                    )
                    if len(cond_intersection) > 0:
                        cond_intersection.append(output)
                        self.cond += cond_intersection
                    if len(body_intersection) > 0:
                        body_intersection.append(output)
                        self.body += body_intersection
                # One get_tuple per consumer, reading the constant from its
                # slot in the argument tuple.
                get_tuple = ParsedTFNode()
                get_tuple.op = "get_tuple"
                get_tuple.name = tfssa._find_free_name("get_tuple_const_")
                get_tuple.attr = {"index": idx + constant_base_index}
                g[get_tuple.name] = get_tuple

                # Attach the get_tuple to the entry of whichever function the
                # consumer belongs to (cond takes precedence). If it belongs
                # to neither, the get_tuple is left without an entry input.
                if output in self.cond:
                    connect_edge(g, cond_body.name, get_tuple.name)
                elif output in self.body:
                    connect_edge(g, body.name, get_tuple.name)
                replace_source(g, enter, output, get_tuple.name)

            # body must accept and return everything
            # so the constant is passed straight from the body entry to
            # make_outputs through its own get_tuple.
            get_tuple = ParsedTFNode()
            get_tuple.op = "get_tuple"
            get_tuple.name = tfssa._find_free_name("get_tuple_const_")
            get_tuple.attr = {"index": idx + constant_base_index}
            g[get_tuple.name] = get_tuple
            connect_edge(g, body.name, get_tuple.name)
            connect_edge(g, get_tuple.name, make_outputs.name)

        # The body must return exactly as many values as the loop receives.
        assert len(g[make_outputs.name].inputs) == len(g[make_inputs.name].inputs)

        # Terminate the body function and record both entry names on the
        # while node.
        output_return = ParsedTFNode()
        output_return.op = "return"
        output_return.name = tfssa._find_free_name("body_return_")
        g[output_return.name] = output_return
        connect_edge(g, make_outputs.name, output_return.name)
        while_loop.attr["cond_function"] = cond_body.name
        while_loop.attr["body_function"] = body.name
        # The original loop plumbing is no longer needed.
        for i in self.enters:
            delete_node(g, i)
        for i in self.next_iterations:
            delete_node(g, i)
        for i in self.constant_enters:
            delete_node(g, i)

        # Each exit now reads element i of the while node's result.
        for i in range(len(exit_corresponding_switch)):
            exit_node = exit_corresponding_switch[i]
            g[exit_node].op = "get_tuple"
            g[exit_node].attr = {"index": i}
        # Collect the node names making up each function; the entry and the
        # terminating node are added explicitly.
        # NOTE(review): FindSubgraph is given a single node name here but the
        # self.cond / self.body collections earlier in this method -- confirm
        # the helper handles both forms as intended.
        cond_function = (
            FindSubgraph(self.loopcond[0]).visit(g, cond_body.name).get_result()
        )
        cond_function = set(cond_function + [self.loopcond[0], cond_body.name])
        body_function = (
            FindSubgraph(output_return.name).visit(g, body.name).get_result()
        )
        body_function = set(body_function + [body.name, output_return.name])

        # trace input constants associated with the cond_graph
        # and the body_graph. These constants can only have one consumer
        # for now. Any more and we will either need to associate
        # it as an argument, or split the constant.
        cond_constants = (
            FindImmediateUpstreamNodes(lambda x: x.op == "Const")
            .visit_many(g, cond_function)
            .get_result()
        )
        body_constants = (
            FindImmediateUpstreamNodes(lambda x: x.op == "Const")
            .visit_many(g, body_function)
            .get_result()
        )
        # The single-consumer restriction described above is not enforced;
        # the check is disabled:
        # for const_node in cond_constants + body_constants:
        #    assert(len(g[const_node].outputs) == 1)

        # Const nodes feeding a function are moved into that function.
        cond_function = cond_function.union(set(cond_constants))
        body_function = body_function.union(set(body_constants))

        # Nodes reachable from the cond function's nodes that are not part of
        # the cond function are deleted from the graph.
        downstream_cond = (
            FindAllReachableNodes(lambda x: True)
            .visit_many(g, cond_function)
            .get_result()
        )
        downstream_cond = set(downstream_cond) - cond_function
        if len(downstream_cond) > 0:
            logging.debug(
                "Disconnecting unused variables in condition function %s",
                downstream_cond,
            )
            for i in downstream_cond:
                delete_node(g, i)

        # Same clean-up for the body function.
        downstream_body = (
            FindAllReachableNodes(lambda x: True)
            .visit_many(g, body_function)
            .get_result()
        )
        downstream_body = set(downstream_body) - body_function
        if len(downstream_body) > 0:
            logging.debug(
                "Disconnecting unused variables in body function %s", downstream_body
            )
            for i in downstream_body:
                delete_node(g, i)

        # Split the node dictionary into the two new functions and the
        # remaining main graph. ``g`` is rebound to a new dict here; the
        # original graph object is no longer referenced by this name.
        cond_graph = {k: v for k, v in g.items() if k in cond_function}
        body_graph = {k: v for k, v in g.items() if k in body_function}
        g = {
            k: v
            for k, v in g.items()
            if k not in cond_function and k not in body_function
        }
        # localize control dependencies
        # In the main graph, reattach the control dependency to the while op:
        # any control edge whose other end is no longer in the main graph is
        # pointed at the while node, and the reverse edge is recorded on it.
        for k, v in g.items():
            for idx in range(len(v.control_inputs)):
                if v.control_inputs[idx] not in g:
                    v.control_inputs[idx] = while_loop.name
                    while_loop.control_outputs.append(k)
            for idx in range(len(v.control_outputs)):
                if v.control_outputs[idx] not in g:
                    v.control_outputs[idx] = while_loop.name
                    while_loop.control_inputs.append(k)

        # in the cond and body graphs, drop non-local control dependencies
        # entirely
        for graph in [cond_graph, body_graph]:
            for k, v in graph.items():
                # Walk the lists backwards so pop(idx) does not shift the
                # entries that are still to be visited.
                for idx in range(len(v.control_inputs) - 1, -1, -1):
                    if v.control_inputs[idx] not in graph:
                        v.control_inputs.pop(idx)

                for idx in range(len(v.control_outputs) - 1, -1, -1):
                    if v.control_outputs[idx] not in graph:
                        v.control_outputs.pop(idx)
        # Publish the rewritten main graph and register the two new functions
        # under the names of their entry nodes.
        tfssa.functions[function_to_functionalize] = SSAFunction(g)
        tfssa.add_function(cond_body.name, SSAFunction(cond_graph))
        tfssa.add_function(body.name, SSAFunction(body_graph))
        return True