scripts/staging/fedplanner/graph.py (158 lines of code) (raw):

# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------

import os
import re
import sys
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx

try:
    # pygraphviz is imported only to probe that it is installed;
    # graphviz_layout needs it at call time.
    import pygraphviz  # noqa: F401
    from networkx.drawing.nx_agraph import graphviz_layout
    HAS_PYGRAPHVIZ = True
except ImportError:
    HAS_PYGRAPHVIZ = False
    print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n"
          "If not installed, we will use an alternative layout (spring_layout).")


# Patterns used by parse_line, compiled once because the parser runs per line.
_NODE_ID_RE = re.compile(r'^\((R|\d+)\)')
_LABEL_RE = re.compile(r'^(.*?)\s*\[')
_KIND_RE = re.compile(r'\[([^\]]+)\]')
_CURLY_RE = re.compile(r'\{([^}]+)\}')
# Digits/dots with an optional sign and exponent, so that a value written in
# scientific notation (e.g. "1.0E7") is captured whole instead of being
# silently truncated to "1.0".
_NUMBER = r'[-+]?[\d.]+(?:[eE][-+]?\d+)?'
_TOTAL_RE = re.compile(r'Total:\s*(' + _NUMBER + r')')
_WEIGHT_RE = re.compile(r'Weight:\s*(' + _NUMBER + r')')
# Comma-separated list of numeric IDs in parentheses; whitespace around the
# commas is tolerated ("(1,2)" and "(1, 2)" both match).
_REFS_RE = re.compile(r'\(\s*(\d+(?:\s*,\s*\d+)*)\s*\)')

# Fill colour per node 'kind' (compared case-insensitively).
_KIND_COLORS = {
    'fout': 'tomato',
    'lout': 'dodgerblue',
    'nref': 'mediumpurple',
}
_DEFAULT_COLOR = 'mediumseagreen'


def parse_line(line: str) -> Optional[dict]:
    """
    Parse a single line from the trace file to extract:
      - Node ID
      - Operation (hop name)
      - Kind (e.g., FOUT, LOUT, NREF)
      - Total cost
      - Weight
      - Refs (list of IDs that this node depends on)

    The line must start with "(R)" or "(<number>)". Judging from the patterns
    below, a line looks roughly like

        (12) TWrite X [FOUT] {Total: 1.0, Weight: 2.0} (10,11)

    Returns None when the line does not start with a node ID. Otherwise
    returns a dict with the keys 'node_id', 'operation', 'kind', 'total',
    'weight' (all strings; empty when the field is absent) and 'refs'
    (a list of ID strings, possibly empty).
    """
    # 1) Node ID in the form "(R)" or "(<number>)" at the very start.
    id_match = _NODE_ID_RE.match(line)
    if not id_match:
        return None
    node_id = id_match.group(1)

    # 2) Everything after the node ID; the operation is the text before the
    #    first "[" (or the whole remainder if there is no bracket).
    after_id = line[id_match.end():].strip()
    label_match = _LABEL_RE.search(after_id)
    operation = label_match.group(1).strip() if label_match else after_id

    # 3) Kind: content of the first "[...]" pair.
    kind_match = _KIND_RE.search(after_id)
    kind = kind_match.group(1).strip() if kind_match else ""

    # 4) Total and weight from the first "{...}" group.
    total = ""
    weight = ""
    curly_match = _CURLY_RE.search(line)
    if curly_match:
        curly_content = curly_match.group(1)
        total_match = _TOTAL_RE.search(curly_content)
        weight_match = _WEIGHT_RE.search(curly_content)
        if total_match:
            total = total_match.group(1)
        if weight_match:
            weight = weight_match.group(1)

    # 5) References: first parenthesised list of numbers after the node ID.
    refs_match = _REFS_RE.search(after_id)
    if refs_match:
        refs = [ref.strip() for ref in refs_match.group(1).split(',')]
    else:
        refs = []

    return {
        'node_id': node_id,
        'operation': operation,
        'kind': kind,
        'total': total,
        'weight': weight,
        'refs': refs,
    }


def build_dag_from_file(filename: str) -> nx.DiGraph:
    """
    Read a trace file line by line and build a directed acyclic graph (DAG)
    using NetworkX.

    Each parsable line becomes a node keyed by its ID with the attributes
    'label' (operation), 'kind', 'total' and 'weight'. For every referenced
    ID an edge "reference -> node" is added. Blank lines and lines that
    parse_line rejects are skipped.

    Raises OSError (from open) if the file cannot be opened.
    """
    graph = nx.DiGraph()
    with open(filename, 'r', encoding='utf-8') as trace:
        for raw_line in trace:
            line = raw_line.strip()
            if not line:
                continue

            info = parse_line(line)
            if not info:
                continue

            node_id = info['node_id']
            # add_node on an existing ID updates its attributes, so a node
            # first seen only as a reference gets its real data here.
            graph.add_node(
                node_id,
                label=info['operation'],
                kind=info['kind'],
                total=info['total'],
                weight=info['weight'],
            )

            for ref in info['refs']:
                # Placeholder for a dependency that has not been defined
                # (yet); never overwrite a node that already has data.
                if ref not in graph:
                    graph.add_node(ref, label=ref, kind="", total="", weight="")
                graph.add_edge(ref, node_id)
    return graph


def _node_color(kind: str) -> str:
    """Return the fill colour for a node of the given kind."""
    return _KIND_COLORS.get(kind.lower(), _DEFAULT_COLOR)


def _node_shape(label: str) -> str:
    """Return the matplotlib marker for a node: '^' twrite, 's' tread, else 'o'."""
    lowered = label.lower()
    # Exclusive classification: a label containing both substrings is drawn
    # once (as a triangle) instead of as two overlapping markers.
    if 'twrite' in lowered:
        return '^'
    if 'tread' in lowered:
        return 's'
    return 'o'


def _output_filename(filename: str) -> str:
    """Return the input path with its extension (if any) replaced by '.png'."""
    # splitext only looks at the last path component, so dots in directory
    # names ("./traces/out", "../trace") do not truncate the path.
    return os.path.splitext(filename)[0] + '.png'


def main():
    """
    Main function that:
      - Reads a filename from command-line arguments
      - Builds a DAG from the file
      - Draws the DAG using matplotlib, saves it as a PNG next to the input
        file and displays it

    Exits with status 1 when no filename is given or the file cannot be read.
    """
    # Get filename from command-line argument
    if len(sys.argv) < 2:
        script = os.path.basename(sys.argv[0]) if sys.argv and sys.argv[0] else 'graph.py'
        print(f"[ERROR] No filename provided.\nUsage: python {script} <filename>")
        sys.exit(1)
    filename = sys.argv[1]
    print(f"[INFO] Running with filename '{filename}'")

    # Build the DAG
    try:
        G = build_dag_from_file(filename)
    except OSError as err:
        print(f"[ERROR] Cannot read '{filename}': {err}")
        sys.exit(1)

    # Print debug info: nodes and edges
    print("Nodes:", G.nodes(data=True))
    print("Edges:", list(G.edges()))

    # Decide on layout
    if HAS_PYGRAPHVIZ:
        # graphviz_layout with rankdir=BT (bottom to top), etc.
        pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8')
    else:
        # Fallback layout if pygraphviz is not installed
        pos = nx.spring_layout(G, seed=42)

    # Dynamically adjust figure size based on number of nodes
    node_count = len(G.nodes())
    fig_width = 10 + node_count / 10.0
    fig_height = 6 + node_count / 10.0

    plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
    ax = plt.gca()
    ax.set_facecolor('white')

    # Generate labels for each node in the format:
    #   node_id: operation_name
    #   C<total> (W<weight>)
    labels = {}
    for n, attrs in G.nodes(data=True):
        labels[n] = (
            f"{n}: {attrs.get('label', n)}\n"
            f" C{attrs.get('total', '')} (W{attrs.get('weight', '')})"
        )

    # Group nodes by marker shape (derived from the operation name);
    # each group needs its own draw call because node_shape is per call.
    nodes_by_shape = {'^': [], 's': [], 'o': []}
    for n, attrs in G.nodes(data=True):
        nodes_by_shape[_node_shape(attrs.get('label', ''))].append(n)

    # Draw nodes group-wise, coloured by 'kind'; nodes sit at the bottom
    # of the z-order (1), below edges (2) and labels (3).
    for shape, nodelist in nodes_by_shape.items():
        if not nodelist:
            continue
        node_collection = nx.draw_networkx_nodes(
            G, pos,
            nodelist=nodelist,
            node_size=800,
            node_color=[_node_color(G.nodes[n].get('kind', '')) for n in nodelist],
            node_shape=shape,
            ax=ax
        )
        node_collection.set_zorder(1)

    edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax)
    # Depending on the drawing mode the result is one collection or a list
    # of arrow patches; handle both.
    if isinstance(edge_collection, list):
        for ec in edge_collection:
            ec.set_zorder(2)
    else:
        edge_collection.set_zorder(2)

    label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax)
    for text in label_dict.values():
        text.set_zorder(3)

    # Set the title
    plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")

    # Explain the label format in the top-right corner
    plt.text(1, 1,
             "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
             fontsize=12, ha='right', va='top', transform=ax.transAxes)

    # Mini-legend for the different 'kind' values (top-left corner)
    plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes)
    plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes)
    plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes)

    plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes)
    plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes)
    plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes)

    plt.axis("off")

    # Save the plot to a file with the same name as the input file, but with a .png extension
    output_filename = _output_filename(filename)
    plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')

    plt.show()


if __name__ == '__main__':
    main()