core/maxframe/core/graph/core.pyx (369 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import deque
from io import StringIO

logger = logging.getLogger(__name__)


cdef class DirectedGraph:
    """Directed graph stored as three dictionaries keyed by node.

    ``_nodes`` maps each node to its attribute dict, ``_successors`` maps
    each node to a ``{successor: edge_attr}`` dict and ``_predecessors`` maps
    each node to a ``{predecessor: edge_attr}`` dict.  Nodes must be hashable
    because they are used as dictionary keys.
    """

    cdef readonly:
        dict _nodes
        dict _predecessors
        dict _successors

    def __init__(self):
        self._nodes = dict()
        self._predecessors = dict()
        self._successors = dict()

    def __iter__(self):
        """Iterate over the nodes of the graph."""
        return iter(self._nodes)

    def __contains__(self, n):
        """Return whether ``n`` is a node of the graph."""
        return n in self._nodes

    def __len__(self):
        """Return the number of nodes."""
        return len(self._nodes)

    def __getitem__(self, n):
        """Return the ``{successor: edge_attr}`` dict of node ``n``."""
        return self._successors[n]

    def contains(self, node):
        """Return whether ``node`` is a node of the graph."""
        return node in self._nodes

    def add_node(self, node, node_attr=None, **node_attrs):
        """Add ``node`` or update the attributes of an existing node.

        Attributes may be passed as a dict (``node_attr``), as keyword
        arguments, or both.  When both are given, the keyword arguments are
        merged into ``node_attr`` in place.

        Raises:
            TypeError: if ``node_attr`` has no ``update`` method.
        """
        if node_attr is None:
            node_attr = node_attrs
        else:
            try:
                node_attr.update(node_attrs)
            except AttributeError:
                raise TypeError('The node_attr argument must be a dictionary')
        self._add_node(node, node_attr)

    cdef inline _add_node(self, node, dict node_attr=None):
        # Insert ``node`` with empty adjacency dicts, or merge ``node_attr``
        # into the attributes of an already registered node.
        if node_attr is None:
            node_attr = dict()
        if node not in self._nodes:
            self._nodes[node] = node_attr
            self._successors[node] = dict()
            self._predecessors[node] = dict()
        else:
            self._nodes[node].update(node_attr)

    def remove_node(self, node):
        """Remove ``node`` together with all edges touching it.

        Raises:
            KeyError: if ``node`` is not in the graph.
        """
        if node not in self._nodes:
            raise KeyError(f'Node {node} does not exist '
                           f'in the directed graph')

        del self._nodes[node]

        # drop the back references held by every successor
        for succ in self._successors[node]:
            del self._predecessors[succ][node]
        del self._successors[node]

        # drop the forward references held by every predecessor
        for pred in self._predecessors[node]:
            del self._successors[pred][node]
        del self._predecessors[node]

    def add_edge(self, u, v, edge_attr=None, **edge_attrs):
        """Add the edge ``u -> v`` or update the attributes of an existing one.

        Attributes may be passed as a dict (``edge_attr``), as keyword
        arguments, or both.  When both are given, the keyword arguments are
        merged into ``edge_attr`` in place.

        Raises:
            TypeError: if ``edge_attr`` has no ``update`` method.
            KeyError: if ``u`` or ``v`` is not a node of the graph.
        """
        if edge_attr is None:
            edge_attr = edge_attrs
        else:
            try:
                edge_attr.update(edge_attrs)
            except AttributeError:
                raise TypeError('The edge_attr argument must be a dictionary')
        self._add_edge(u, v, edge_attr)

    cdef inline _add_edge(self, u, v, edge_attr=None):
        # Register the edge ``u -> v`` in both adjacency maps.  Both end
        # points must already be nodes of the graph.
        cdef:
            dict u_succ, v_pred

        if u not in self._nodes:
            raise KeyError(f'Node {u} does not exist in the directed graph')
        if v not in self._nodes:
            raise KeyError(f'Node {v} does not exist in the directed graph')

        if edge_attr is None:
            edge_attr = dict()

        u_succ = self._successors[u]
        if v in u_succ:
            u_succ[v].update(edge_attr)
        else:
            u_succ[v] = edge_attr

        v_pred = self._predecessors[v]
        if u not in v_pred:
            # `update` is not necessary, as they point to the same object
            v_pred[u] = edge_attr

    def remove_edge(self, u, v):
        """Remove the edge ``u -> v``.

        Raises:
            KeyError: if the edge (or one of its end points) does not exist.
        """
        try:
            del self._successors[u][v]
            del self._predecessors[v][u]
        except KeyError:
            raise KeyError(f'Edge {u}->{v} does not exist in the directed graph')

    def has_successor(self, u, v):
        """Return whether ``v`` is a direct successor of ``u``."""
        return (u in self._successors) and (v in self._successors[u])

    def has_predecessor(self, u, v):
        """Return whether ``v`` is a direct predecessor of ``u``."""
        return (u in self._predecessors) and (v in self._predecessors[u])

    def iter_nodes(self, data=False):
        """Iterate over nodes, or ``(node, attr_dict)`` pairs when ``data``."""
        if data:
            return iter(self._nodes.items())
        return iter(self._nodes)

    def iter_successors(self, n):
        """Return an iterator over the successors of ``n``.

        Raises:
            KeyError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._successors[n])
        except KeyError:
            raise KeyError(f'Node {n} does not exist in the directed graph')

    cpdef list successors(self, n):
        """Return the successors of ``n`` as a new list.

        Raises:
            KeyError: if ``n`` is not in the graph.
        """
        try:
            return list(self._successors[n])
        except KeyError:
            raise KeyError(f'Node {n} does not exist in the directed graph')

    def iter_predecessors(self, n):
        """Return an iterator over the predecessors of ``n``.

        Raises:
            KeyError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._predecessors[n])
        except KeyError:
            raise KeyError(f'Node {n} does not exist in the directed graph')

    cpdef list predecessors(self, n):
        """Return the predecessors of ``n`` as a new list.

        Raises:
            KeyError: if ``n`` is not in the graph.
        """
        try:
            return list(self._predecessors[n])
        except KeyError:
            raise KeyError(f'Node {n} does not exist in the directed graph')

    cpdef int count_successors(self, n):
        """Return the number of successors of ``n``."""
        return len(self._successors[n])

    cpdef int count_predecessors(self, n):
        """Return the number of predecessors of ``n``."""
        return len(self._predecessors[n])

    def iter_indep(self, bint reverse=False):
        """Yield nodes without predecessors (without successors if ``reverse``)."""
        cdef dict preds
        preds = self._predecessors if not reverse else self._successors
        for n, p in preds.items():
            if len(p) == 0:
                yield n

    cpdef int count_indep(self, reverse=False):
        """Count nodes without predecessors (without successors if ``reverse``)."""
        cdef:
            dict preds
            int result = 0
        preds = self._predecessors if not reverse else self._successors
        for n, p in preds.items():
            if len(p) == 0:
                result += 1
        return result

    def dfs(self, start=None, visit_predicate=None, successors=None, reverse=False):
        """Yield nodes in depth-first order.

        Args:
            start: a node, or a list/tuple of nodes, to start from.  Any falsy
                value (``None``, an empty list, ...) starts from the nodes
                returned by ``iter_indep(reverse=reverse)``.
            visit_predicate: ``'all'`` to visit every popped node
                unconditionally, or a callable ``(node, visited) -> bool``.
                By default a node is visited only once all of its
                predecessors (in traversal direction) have been visited.
            successors: optional callable ``node -> iterable`` giving the
                nodes to push after visiting a node.  Defaults to the
                successors in traversal direction.
            reverse: traverse edges backwards when true.
        """
        cdef:
            set visited = set()
            list stack
            bint visit_all = False

        if reverse:
            pred_fun, succ_fun = self.successors, self.predecessors
        else:
            pred_fun, succ_fun = self.predecessors, self.successors

        # NOTE(review): unlike bfs(), which tests ``start is not None``, any
        # falsy ``start`` falls back to the independent nodes here.  Kept
        # as-is because callers may rely on it.
        if start:
            if not isinstance(start, (list, tuple)):
                start = [start]
            stack = list(start)
        else:
            stack = list(self.iter_indep(reverse=reverse))

        def _default_visit_predicate(n, visited):
            cdef list preds
            preds = pred_fun(n)
            return not preds or all(pred in visited for pred in preds)

        successors = successors or succ_fun
        visit_all = (visit_predicate == 'all')
        visit_predicate = visit_predicate or _default_visit_predicate

        while stack:
            node = stack.pop()
            if node in visited:
                continue
            # Must follow the traversal direction.  Using
            # ``self.predecessors`` here made ``reverse=True`` re-push a
            # blocked node forever, because the nodes it was waiting for
            # were never scheduled.
            preds = pred_fun(node)
            if visit_all or visit_predicate(node, visited):
                yield node
                visited.add(node)
                stack.extend(n for n in successors(node) if n not in visited)
            else:
                # retry this node after its unvisited predecessors
                stack.append(node)
                stack.extend(n for n in preds if n not in visited)

    def bfs(self, start=None, visit_predicate=None, successors=None, reverse=False):
        """Yield nodes in breadth-first order.

        Args:
            start: a node, or a list/tuple of nodes, to start from.  ``None``
                starts from the nodes returned by
                ``iter_indep(reverse=reverse)``.
            visit_predicate: ``'all'`` to visit every dequeued node
                unconditionally, or a callable ``(node, visited) -> bool``.
                By default a node is visited only once all of its
                predecessors (in traversal direction) have been visited.
            successors: optional callable ``node -> iterable`` giving the
                nodes to enqueue after visiting a node.  Defaults to the
                successors in traversal direction.
            reverse: traverse edges backwards when true.
        """
        cdef:
            object queue
            object node
            set visited = set()
            bint visit_all = False

        if reverse:
            pred_fun, succ_fun = self.successors, self.predecessors
        else:
            pred_fun, succ_fun = self.predecessors, self.successors

        if start is not None:
            if not isinstance(start, (list, tuple)):
                start = [start]
            queue = deque(start)
        else:
            queue = deque(self.iter_indep(reverse=reverse))

        def _default_visit_predicate(n, visited):
            preds = pred_fun(n)
            return not preds or all(pred in visited for pred in preds)

        successors = successors or succ_fun
        visit_all = (visit_predicate == 'all')
        visit_predicate = visit_predicate or _default_visit_predicate

        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            preds = pred_fun(node)
            if visit_all or visit_predicate(node, visited):
                yield node
                visited.add(node)
                queue.extend(n for n in successors(node) if n not in visited)
            else:
                # retry this node after its unvisited predecessors
                queue.append(node)
                queue.extend(n for n in preds if n not in visited)

    def copy(self):
        """Return a new graph of the same type with the same nodes and edges.

        Nodes and edges are re-added without attributes, so node and edge
        attribute dicts are not carried over to the copy.
        """
        cdef DirectedGraph graph = type(self)()
        for n in self:
            if n not in graph._nodes:
                graph._add_node(n)
            for succ in self.iter_successors(n):
                if succ not in graph._nodes:
                    graph._add_node(succ)
                graph._add_edge(n, succ)
        return graph

    def copyto(self, DirectedGraph other_graph):
        """Overwrite the content of ``other_graph`` with this graph's content.

        Does nothing when ``other_graph`` is this graph.
        """
        if other_graph is self:
            return
        # NOTE(review): these are shallow copies; the per-node attribute and
        # adjacency dicts stay shared between both graphs.
        other_graph._nodes = self._nodes.copy()
        other_graph._predecessors = self._predecessors.copy()
        other_graph._successors = self._successors.copy()

    def build_undirected(self):
        """Return a ``DirectedGraph`` holding every edge in both directions.

        Node and edge attributes are not carried over.
        """
        cdef DirectedGraph graph = DirectedGraph()
        for n in self:
            if n not in graph._nodes:
                graph._add_node(n)
            for succ in self._successors[n]:
                if succ not in graph._nodes:
                    graph._add_node(succ)
                graph._add_edge(n, succ)
                graph._add_edge(succ, n)
        return graph

    def build_reversed(self):
        """Return a new graph of the same type with every edge flipped.

        Node and edge attributes are not carried over.
        """
        cdef DirectedGraph graph = type(self)()
        for n in self:
            if n not in graph._nodes:
                graph._add_node(n)
            for succ in self._successors[n]:
                if succ not in graph._nodes:
                    graph._add_node(succ)
                graph._add_edge(succ, n)
        return graph

    @classmethod
    def _repr_in_dot(cls, val):
        """Format an attribute value for DOT output.

        Booleans become ``true``/``false``, strings are wrapped in double
        quotes and any other value is returned unchanged.
        """
        if isinstance(val, bool):
            return 'true' if val else 'false'
        if isinstance(val, str):
            return f'"{val}"'
        return val

    def _extract_operators(self, node):
        """Return the operators to draw for ``node`` (here: ``[node.op]``)."""
        return [node.op]

    def to_dot(
        self,
        graph_attrs=None,
        node_attrs=None,
        trunc_key=5,
        result_chunk_keys=None,
        show_columns=False):
        """Render the graph as a Graphviz DOT string.

        Every operator returned by ``_extract_operators`` is drawn as a
        circle and each of its input/output chunks as a box.

        Args:
            graph_attrs: optional mapping written as the ``graph [...]`` line.
            node_attrs: optional mapping written as the ``node [...]`` line.
            trunc_key: number of leading key characters kept in labels.
            result_chunk_keys: optional container of chunk keys; matching
                output chunks are drawn with a filled style.
            show_columns: when true, label operator-to-output edges with the
                output's column names (``dtypes.index``), its ``name``, or
                ``N/A``.

        Returns:
            The DOT source as a string.
        """
        sio = StringIO()
        sio.write('digraph {\n')
        sio.write('splines=curved\n')
        sio.write('rankdir=BT\n')

        if graph_attrs:
            sio.write('graph [{0}];\n'.format(
                ' '.join(f'{k}={self._repr_in_dot(v)}' for k, v in graph_attrs.items())))
        if node_attrs:
            sio.write('node [{0}];\n'.format(
                ' '.join(f'{k}={self._repr_in_dot(v)}' for k, v in node_attrs.items())))

        chunk_style = '[shape=box]'
        operator_style = '[shape=circle]'

        # keys of chunks and operators whose node line was already emitted
        visited = set()

        def get_col_names(obj):
            if hasattr(obj, "dtypes"):
                # stringify labels: str.join raises TypeError on non-string
                # column labels such as integers
                col_names = ','.join(str(col) for col in obj.dtypes.index)
                return f"\"{col_names}\""
            elif hasattr(obj, "name"):
                return f"\"{obj.name}\""
            else:
                return "\"N/A\""

        for node in self.iter_nodes():
            for op in self._extract_operators(node):
                op_name = type(op).__name__
                if op.stage is not None:
                    op_name = f'{op_name}:{op.stage.name}'
                if op.key in visited:
                    continue
                for input_chunk in (op.inputs or []):
                    if input_chunk.key not in visited:
                        sio.write(f'"Chunk:{self._gen_chunk_key(input_chunk, trunc_key)}" {chunk_style}\n')
                        visited.add(input_chunk.key)
                    if op.key not in visited:
                        sio.write(f'"{op_name}:{op.key[:trunc_key]}_{id(op)}" {operator_style}\n')
                        visited.add(op.key)
                    sio.write(f'"Chunk:{self._gen_chunk_key(input_chunk, trunc_key)}" -> '
                              f'"{op_name}:{op.key[:trunc_key]}_{id(op)}"\n')

                for output_chunk in (op.outputs or []):
                    if output_chunk.key not in visited:
                        tmp_chunk_style = chunk_style
                        if result_chunk_keys and output_chunk.key in result_chunk_keys:
                            tmp_chunk_style = '[shape=box,style=filled,fillcolor=cadetblue1]'
                        sio.write(f'"Chunk:{self._gen_chunk_key(output_chunk, trunc_key)}" {tmp_chunk_style}\n')
                        visited.add(output_chunk.key)
                    if op.key not in visited:
                        sio.write(f'"{op_name}:{op.key[:trunc_key]}_{id(op)}" {operator_style}\n')
                        visited.add(op.key)
                    sio.write(f'"{op_name}:{op.key[:trunc_key]}_{id(op)}" -> '
                              f'"Chunk:{self._gen_chunk_key(output_chunk, trunc_key)}"')
                    if show_columns:
                        sio.write(f' [ label={get_col_names(output_chunk)} ]')
                    sio.write("\n")

        sio.write('}')
        return sio.getvalue()

    @classmethod
    def _gen_chunk_key(cls, chunk, trunc_key):
        """Return a shortened label for ``chunk.key``.

        When the key contains ``_``, only the part before the first ``_`` is
        truncated to ``trunc_key`` characters and the remainder is kept;
        otherwise the whole key is truncated.
        """
        if "_" in chunk.key:
            key, index = chunk.key.split("_", 1)
            return "_".join([key[:trunc_key], index])
        else:  # pragma: no cover
            return chunk.key[:trunc_key]

    def _repr_svg_(self):  # pragma: no cover
        """Return the SVG produced by graphviz for ``to_dot()``."""
        from graphviz import Source
        return Source(self.to_dot())._repr_svg_()

    def _repr_mimebundle_(self, *args, **kw):  # pragma: no cover
        """Return the MIME bundle produced by graphviz for ``to_dot()``."""
        from graphviz import Source
        return Source(self.to_dot())._repr_mimebundle_(*args, **kw)

    def compose(self, list keys=None):
        """Delegate to ``Fusion(self).compose(keys=keys)`` and return its result."""
        from ...optimizes.chunk_graph.fuse import Fusion
        return Fusion(self).compose(keys=keys)

    def decompose(self, nodes=None):
        """Delegate to ``Fusion(self).decompose(nodes=nodes)``."""
        from ...optimizes.chunk_graph.fuse import Fusion
        Fusion(self).decompose(nodes=nodes)

    def view(self, filename='default', graph_attrs=None, trunc_key=5,
             node_attrs=None, result_chunk_keys=None, show_columns=False):  # pragma: no cover
        """Render ``to_dot(...)`` with graphviz and open it via ``Source.view``."""
        from graphviz import Source

        g = Source(self.to_dot(graph_attrs, node_attrs, trunc_key=trunc_key,
                               result_chunk_keys=result_chunk_keys,
                               show_columns=show_columns))
        g.view(filename=filename, cleanup=True)

    def to_dag(self):
        """Return a ``DAG`` holding this graph's nodes and edges.

        No acyclicity check is performed here.
        """
        dag = DAG()
        # NOTE(review): these are shallow copies; the per-node attribute and
        # adjacency dicts stay shared with this graph.
        dag._nodes = self._nodes.copy()
        dag._predecessors = self._predecessors.copy()
        dag._successors = self._successors.copy()
        return dag


class GraphContainsCycleError(Exception):
    """Raised by ``DAG.topological_iter`` when it detects a cycle."""
    pass


cdef class DAG(DirectedGraph):
    """Directed graph that additionally offers topological iteration."""

    def to_dag(self):
        """Return this graph itself."""
        return self

    def topological_iter(self, succ_checker=None, reverse=False):
        """Yield the nodes in a topological order.

        Args:
            succ_checker: optional callable
                ``(successor, remaining_predecessors) -> bool`` deciding
                whether a successor may be scheduled.  By default a
                successor is scheduled once it has no remaining
                predecessors.
            reverse: iterate against the edge direction when true.

        Raises:
            GraphContainsCycleError: if no starting node exists, a successor
                was already visited, or not every node could be visited.
                As this is a generator, the error surfaces during iteration.
        """
        cdef:
            dict preds, succs
            set visited = set()
            list stack

        if len(self) == 0:
            return

        if reverse:
            preds, succs = self._successors, self._predecessors
        else:
            preds, succs = self._predecessors, self._successors

        # copy predecessors and successors
        succs = dict((k, set(v)) for k, v in succs.items())
        preds = dict((k, set(v)) for k, v in preds.items())

        def _default_succ_checker(_, predecessors):
            return len(predecessors) == 0

        succ_checker = succ_checker or _default_succ_checker

        stack = list((p for p, l in preds.items() if len(l) == 0))
        if not stack:
            raise GraphContainsCycleError
        while stack:
            node = stack.pop()
            yield node
            visited.add(node)
            for succ in succs.get(node, {}):
                if succ in visited:
                    raise GraphContainsCycleError
                succ_preds = preds[succ]
                succ_preds.remove(node)
                if succ_checker(succ, succ_preds):
                    stack.append(succ)
        if len(visited) != len(self):
            raise GraphContainsCycleError