in optimum/fx/parallelization/utils.py [0:0]
def stable_topological_sort(graph: Graph) -> None:
    """Reorder the nodes of ``graph`` in place so each node follows its inputs.

    The sort is stable in the sense that a node is only moved when it is not
    already directly after the previously placed node; a graph that is
    already topologically ordered is left untouched.

    Args:
        graph: The FX graph whose node list is reordered in place.

    Raises:
        AssertionError: If some nodes can never be placed, e.g. because the
            graph contains a dependency cycle or a node refers to an input
            that is never processed as part of ``graph``.
    """

    def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
        # Collect every value `map_arg` visits in the positional and keyword
        # arguments of `n`, in traversal order.
        args: List[torch.fx.node.Argument] = []
        torch.fx.map_arg((n.args, n.kwargs), args.append)
        return args

    # Nodes are in exactly one of these three collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order, so
    #   that `pop()` yields them in graph order):
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    #   order.
    ready = set()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    #   dependency.
    waiting = defaultdict(list)

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()
        waiting_for = [x for x in _args(node) if x not in ready]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
            continue

        ready.add(node)
        # Only move the node when it is not already right behind the cursor;
        # the very first ready node stays where it is.
        if cursor is not None and cursor.next is not node:
            cursor.append(node)
        cursor = node

        # Mark the nodes that have been waiting for this node to finish as
        # ready to check again.
        pending.extend(reversed(waiting.pop(node, ())))

    # Raised explicitly instead of using an `assert` statement so the check is
    # not stripped under `python -O`. AssertionError is kept as the exception
    # type for callers that already handle it.
    if waiting or len(ready) != len(graph.nodes):
        unsorted = [n.name for dependents in waiting.values() for n in dependents]
        raise AssertionError(
            "stable_topological_sort could not order all nodes "
            f"({len(ready)} of {len(graph.nodes)} placed); "
            f"nodes with unresolved inputs: {unsorted}"
        )