def layers_topological_order()

in tensorwatch/model_graph/hiddenlayer/summary_graph.py [0:0]


    def layers_topological_order(self, recurrent=False):
        """
        Prepares an ordered list of layers to quantize sequentially. This list has all the layers ordered by their
        topological order in the graph.

        Every op in the adjacency map is ranked by the length of the longest path that leads to it from a
        root op (an op without predecessors), so along every followed edge the successor is ranked strictly
        after its predecessor. Only ops whose names appear in ``self._src_model.named_modules()`` are kept.
        A non-empty result is cached in ``self._layers_topological_order`` and returned on later calls.

        Args:
            recurrent (bool): indication on whether the model might have recurrent connections.
                When True, an edge leading back to an already-ranked ancestor is not followed.
                A graph that contains cycles must be processed with ``recurrent=True``; otherwise the
                ranking recursion does not terminate.

        Returns:
            list: names of the model's modules sorted by ascending rank. Ops with equal rank keep the
            iteration order of the adjacency map.
        """
        # An empty (falsy) cached value is treated as "not computed yet" and is recomputed.
        if self._layers_topological_order:
            return self._layers_topological_order
        adj_map = self.adjacency_map()
        # Rank 0 means either "root" or "not reached yet"; `ranked` tells the two apart.
        ranks = {op_name: 0 for op_name in adj_map}
        ranked = set()

        def _is_reachable(start_name, target_name):
            """Return True if `target_name` can be reached from `start_name` through one or more edges."""
            # Iterative search with a visited set: terminates on cyclic graphs and
            # never walks the same op twice.
            seen = set()
            pending = [op.name for op in adj_map[start_name].successors]
            while pending:
                name = pending.pop()
                if name == target_name:
                    return True
                if name in seen:
                    continue
                seen.add(name)
                pending.extend(op.name for op in adj_map[name].successors)
            return False

        def _is_recurrent_ancestor(dest_name, src_name):
            """Return True if the edge src -> dest leads back to an already-ranked ancestor of src."""
            # The cheap rank comparison goes first so the graph walk is only done when needed.
            return 0 < ranks[dest_name] < ranks[src_name] and _is_reachable(dest_name, src_name)

        def _rank_op(op_name, rank):
            """Raise the rank of `op_name` to `rank` and propagate the change to its successors."""
            # Keep the largest rank seen over all paths. A visit that does not raise the rank
            # cannot change any successor either, so the sub-graph is not traversed again.
            if op_name in ranked and rank <= ranks[op_name]:
                return
            ranked.add(op_name)
            ranks[op_name] = rank
            for child_op in adj_map[op_name].successors:
                # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
                if recurrent and _is_recurrent_ancestor(child_op.name, op_name):
                    continue
                _rank_op(child_op.name, ranks[op_name] + 1)

        roots = [name for name, entry in adj_map.items() if len(entry.predecessors) == 0]
        for root_name in roots:
            _rank_op(root_name, 0)

        # Take only the modules from the original model
        module_dict = dict(self._src_model.named_modules())
        # sorted() is stable, so ops sharing a rank stay in the adjacency map's order.
        ret = sorted((name for name in adj_map if name in module_dict), key=ranks.__getitem__)
        # Check that only the actual roots have a rank of 0
        assert {name for name in ret if ranks[name] == 0} <= set(roots)
        self._layers_topological_order = ret
        return ret