in tfx_airflow/notebooks/utils.py [0:0]
def plot_artifact_lineage(self, g):
    """Plots a lineage `nx.DiGraph` using networkx and matplotlib.

    Nodes with a positive id are drawn in cyan and labelled as artifacts;
    nodes with a negative id are drawn in magenta and labelled as executions.
    The two groups are laid out as two horizontal rows (artifacts at the
    bottom, executions above), each ordered left to right by absolute node
    id. Every node gets an invisible (white) companion node that carries its
    type name as a text label. Edges whose `is_cached` attribute is truthy
    are drawn in yellow, all other edges in black.

    Args:
      g: A `nx.DiGraph` object. Each node must carry a `_label_` attribute
        (its type name) and each edge an `is_cached` attribute.
    """
    # Work on a copy so the auxiliary label nodes never reach the caller's
    # graph.
    dag = g.copy(as_view=False)
    # Offset used to derive the id of a node's label anchor.
    # NOTE(review): this assumes node ids are non-zero and smaller than
    # 10000 in absolute value; otherwise anchor ids would collide with real
    # node ids -- confirm against the caller.
    label_anchor_id = 10000
    for node_id in g.nodes:
        if node_id > 0:
            dag.add_node(label_anchor_id + node_id)
        else:
            dag.add_node(node_id - label_anchor_id)

    # Assign node colors and labels, and collect the real (non-anchor) nodes.
    # The colors are kept in a list (one entry per node, in `dag.nodes`
    # order): a plain string of single-letter codes is not accepted as a
    # color sequence by current matplotlib releases.
    node_color = []
    node_labels = {}
    a_nodes = []
    e_nodes = []
    for node_id in dag.nodes:
        if node_id > 0 and node_id < label_anchor_id:
            # artifact
            node_color.append('c')
            node_labels[node_id] = abs(node_id)
            a_nodes.append(node_id)
        elif node_id > 0 and node_id >= label_anchor_id:
            # artifact label: break the CamelCase type name into one word
            # per line.
            node_color.append('w')
            type_name = dag.nodes[node_id - label_anchor_id]['_label_']
            type_segments = re.split(r'([A-Z][a-z]+)', type_name)
            node_labels[node_id] = '\n'.join(s for s in type_segments if s)
        elif node_id < 0 and node_id > -1 * label_anchor_id:
            # execution
            node_color.append('m')
            node_labels[node_id] = abs(node_id)
            e_nodes.append(node_id)
        else:
            # execution label: keep only the last dotted component.
            node_color.append('w')
            type_name = dag.nodes[node_id + label_anchor_id]['_label_']
            node_labels[node_id] = type_name.split('.')[-1]

    # assign edge color
    edge_color = [
        'y' if labels['is_cached'] else 'k'
        for (_, _, labels) in dag.edges(data=True)
    ]

    # Lay out the two rows, each centered around x = 0.
    a_nodes.sort(key=abs)
    e_nodes.sort(key=abs)
    a_node_y = 0
    e_node_y = 0.035
    a_offset = -0.5 if len(a_nodes) % 2 == 0 else 0
    e_offset = -0.5 if len(e_nodes) % 2 == 0 else 0
    a_node_x_min = -1 * len(a_nodes) / 2 + a_offset
    e_node_x_min = -1 * len(e_nodes) / 2 + e_offset
    a_node_x = a_node_x_min
    e_node_x = e_node_x_min
    node_step = 1
    pos = {}
    for a_id in a_nodes:
        pos[a_id] = [a_node_x, a_node_y]
        # The label anchor sits just below its artifact.
        pos[a_id + label_anchor_id] = [a_node_x, a_node_y - 0.01]
        a_node_x += node_step
    for e_id in e_nodes:
        pos[e_id] = [e_node_x, e_node_y]
        # The label anchor sits just above its execution.
        pos[e_id - label_anchor_id] = [e_node_x, e_node_y + 0.01]
        e_node_x += node_step

    nx.draw(dag, pos=pos,
            node_size=500, node_color=node_color,
            labels=node_labels, node_shape='o', font_size=8.3, label='abc',
            width=0.5, edge_color=edge_color)

    # Legend, drawn as annotations in axes-fraction coordinates.
    a_bbox_props = dict(boxstyle='square,pad=0.3', fc='c', ec='b', lw=0)
    plt.annotate(' Artifacts ',
                 xycoords='axes fraction', xy=(0.85, 0.575),
                 textcoords='axes fraction', xytext=(0.85, 0.575),
                 bbox=a_bbox_props, alpha=0.6)
    e_bbox_props = dict(boxstyle='square,pad=0.3', fc='m', ec='b', lw=0)
    plt.annotate('Executions',
                 xycoords='axes fraction', xy=(0.85, 0.5),
                 textcoords='axes fraction', xytext=(0.85, 0.5),
                 bbox=e_bbox_props, alpha=0.6)
    plt.annotate(' Cached ',
                 xycoords='axes fraction', xy=(0.85, 0.425),
                 textcoords='axes fraction', xytext=(0.85, 0.425),
                 alpha=0.6)
    plt.annotate('', xycoords='axes fraction', xy=(0.975, 0.405),
                 textcoords='axes fraction', xytext=(0.845, 0.405),
                 arrowprops=dict(edgecolor='y', arrowstyle='->', alpha=0.6))

    # Horizontal limits. `max(..., 1)` guards the division when the graph
    # contains no artifact nodes.
    x_lim_left = max(-2 - 1.5 / max(len(a_nodes), 1),
                     min(a_node_x_min, e_node_x_min) - 1.0)
    x_lim_right = max(a_node_x, e_node_x) + 0.1
    plt.xlim(x_lim_left, x_lim_right)
    plt.show()