def topological_sort()

in tinynn/converter/operators/graph.py [0:0]


    def topological_sort(self) -> typing.List[int]:
        """Sort the graph topologically.

        Constant nodes are emitted first, in vertex order. The remaining nodes are
        then visited with an iterative depth-first traversal. The traversal is
        seeded by:

        - the input nodes;
        - nodes with a non-negative ``node_type`` that have no incoming edges;
        - nodes with a non-negative ``node_type`` whose inputs are all constants.

        A node is emitted only after every one of its predecessors has been emitted.

        Returns:
            typing.List[int]: The sorted indices of the nodes
        """

        visited = set()
        indices = []

        def all_inputs_visited(vertex) -> bool:
            # A node is ready only when every predecessor has already been emitted
            return all(edge.source in visited for edge in vertex.in_edges())

        # Seed nodes for the traversal: graph inputs, plus nodes with a
        # non-negative node_type and no incoming edges.
        inputs = [v for v in self.graph.vs if v['node_type'] == ExtendedOperator.INPUT_NODE]
        other_input_nodes = [v for v in self.graph.vs if v['node_type'] >= 0 and v.indegree() == 0]
        # Index set mirroring `other_input_nodes` for O(1) membership tests.
        # The list itself is kept because its order drives the traversal order.
        seeded = {v.index for v in other_input_nodes}

        # Constants are all known, so just marking them here.
        constants = [v for v in self.graph.vs if v['node_type'] == ExtendedOperator.CONSTANT_NODE]
        for c in constants:
            indices.append(c.index)
            visited.add(c.index)
            for out_edge in c.out_edges():
                v = out_edge.target_vertex
                if v.index in seeded or not all_inputs_visited(v):
                    continue

                if v['node_type'] >= 0:
                    # Every input of this node is a constant, so it becomes an
                    # additional starting point of the traversal.
                    other_input_nodes.append(v)
                    seeded.add(v.index)
                elif v['node_type'] != ExtendedOperator.OUTPUT_NODE:
                    type_name = ExtendedOperator(v['node_type']).type_name()
                    log.warning(
                        f'The child node of a constant node is of type {type_name}, which is unexpected'
                    )
                # NOTE(review): non-seeded children (e.g. an output node fed only
                # by constants) are never pushed, so they do not end up in
                # `indices` — confirm this is intended.

        # NOTE(review): `other_input_nodes` also contains the constant-fed nodes
        # appended above, so the "orphaned" warning fires for them too.
        for v in other_input_nodes:
            if v['node_type'] not in (
                ExtendedOperator.ASSIGN_VARIABLE,
                ExtendedOperator.READ_VARIABLE,
                ExtendedOperator.RANDOM_STANDARD_NORMAL,
                ExtendedOperator.MULTINOMIAL,
                ExtendedOperator.RANDOM_UNIFORM,
            ):
                output_name = v['outputs'][0]
                type_name = v['op'].type_name()
                log.warning(f'{type_name}({output_name}) is an orphaned node, which is unexpected')

        # Emulating DFS with an explicit stack. The seeds are stored reversed so
        # that they are popped in their original order.
        stack = list(reversed(inputs + other_input_nodes))

        while stack:
            v = stack.pop()

            # Skip nodes already emitted, and nodes that still have an unvisited
            # predecessor. The latter get pushed again when that predecessor is
            # emitted.
            if v.index in visited or not all_inputs_visited(v):
                continue

            visited.add(v.index)
            indices.append(v.index)

            # Push successors in reverse so the first out-edge is explored first
            stack.extend(e.target_vertex for e in reversed(v.out_edges()))

        return indices