def plot_artifact_lineage()

in tfx_airflow/notebooks/utils.py [0:0]


  def plot_artifact_lineage(self, g):
    """Plots an artifact lineage `nx.DiGraph` with networkx and matplotlib.

    Node ids encode the node kind: positive ids are artifacts and negative
    ids are executions. Artifacts are drawn in cyan along the bottom row and
    executions in magenta along the top row; within each row the nodes are
    ordered left to right by the absolute value of their id. Every node gets
    a companion white "label anchor" node that displays its type name, read
    from the node's `_label_` attribute. Edges whose `is_cached` attribute is
    truthy are drawn in yellow, all other edges in black.

    Args:
      g: A `nx.DiGraph` object. Node ids are expected to be non-zero integers
        whose absolute value is below 10000; every node must carry a
        `_label_` attribute and every edge an `is_cached` attribute. The
        graph is copied before auxiliary nodes are added, so `g` is not
        modified.
    """
    # Work on a copy so the auxiliary label-anchor nodes never leak into `g`.
    dag = g.copy(as_view=False)
    # Anchor ids are the real id shifted away from zero by this offset, which
    # keeps them disjoint from real ids as long as |id| < label_anchor_id.
    label_anchor_id = 10000
    for node_id in g.nodes:
      if node_id > 0:
        dag.add_node(label_anchor_id + node_id)
      else:
        dag.add_node(node_id - label_anchor_id)

    # Classify every node once: pick its color and label, and collect the
    # real artifact / execution nodes for the layout step. `node_color` is an
    # explicit list (one entry per node, in `dag.nodes` order) because newer
    # matplotlib releases no longer accept a string of single-character
    # colors as a color sequence.
    node_color = []
    node_labels = {}
    a_nodes = []
    e_nodes = []
    for node_id in dag.nodes:
      if 0 < node_id < label_anchor_id:
        # artifact
        node_color.append('c')
        node_labels[node_id] = abs(node_id)
        a_nodes.append(node_id)
      elif node_id >= label_anchor_id:
        # artifact label: break the CamelCase type name into one word per line
        node_color.append('w')
        type_name = dag.nodes[node_id - label_anchor_id]['_label_']
        type_segments = re.split(r'([A-Z][a-z]+)', type_name)
        node_labels[node_id] = '\n'.join(s for s in type_segments if s)
      elif -label_anchor_id < node_id < 0:
        # execution
        node_color.append('m')
        node_labels[node_id] = abs(node_id)
        e_nodes.append(node_id)
      else:
        # execution label: keep only the last component of the dotted name
        node_color.append('w')
        type_name = dag.nodes[node_id + label_anchor_id]['_label_']
        node_labels[node_id] = type_name.split('.')[-1]

    # assign edge color: cached edges in yellow, everything else in black
    edge_color = ['y' if labels['is_cached'] else 'k'
                  for (_, _, labels) in dag.edges(data=True)]

    # Lay the nodes out in two horizontal rows, each ordered by |id|. An
    # even-sized row is shifted by an extra half step.
    a_nodes.sort(key=abs)
    e_nodes.sort(key=abs)
    a_node_y = 0
    e_node_y = 0.035
    a_offset = -0.5 if len(a_nodes) % 2 == 0 else 0
    e_offset = -0.5 if len(e_nodes) % 2 == 0 else 0
    a_node_x_min = -1 * len(a_nodes) / 2 + a_offset
    e_node_x_min = -1 * len(e_nodes) / 2 + e_offset
    node_step = 1

    pos = {}
    a_node_x = a_node_x_min
    for a_id in a_nodes:
      pos[a_id] = [a_node_x, a_node_y]
      # the artifact label sits just below its node
      pos[a_id + label_anchor_id] = [a_node_x, a_node_y - 0.01]
      a_node_x += node_step
    e_node_x = e_node_x_min
    for e_id in e_nodes:
      pos[e_id] = [e_node_x, e_node_y]
      # the execution label sits just above its node
      pos[e_id - label_anchor_id] = [e_node_x, e_node_y + 0.01]
      e_node_x += node_step

    nx.draw(dag, pos=pos,
            node_size=500, node_color=node_color,
            labels=node_labels, node_shape='o', font_size=8.3, label='abc',
            width=0.5, edge_color=edge_color)

    # Hand-drawn legend, positioned in axes-fraction coordinates.
    a_bbox_props = dict(boxstyle='square,pad=0.3', fc='c', ec='b', lw=0)
    plt.annotate('  Artifacts  ',
                 xycoords='axes fraction', xy=(0.85, 0.575),
                 textcoords='axes fraction', xytext=(0.85, 0.575),
                 bbox=a_bbox_props, alpha=0.6)
    e_bbox_props = dict(boxstyle='square,pad=0.3', fc='m', ec='b', lw=0)
    plt.annotate('Executions',
                 xycoords='axes fraction', xy=(0.85, 0.5),
                 textcoords='axes fraction', xytext=(0.85, 0.5),
                 bbox=e_bbox_props, alpha=0.6)
    plt.annotate('  Cached    ',
                 xycoords='axes fraction', xy=(0.85, 0.425),
                 textcoords='axes fraction', xytext=(0.85, 0.425),
                 alpha=0.6)
    plt.annotate('', xycoords='axes fraction', xy=(0.975, 0.405),
                 textcoords='axes fraction', xytext=(0.845, 0.405),
                 arrowprops=dict(edgecolor='y', arrowstyle='->', alpha=0.6))

    # After the layout loops, a_node_x / e_node_x sit one step past the last
    # node of their row, so they mark the right edge of the drawing. The
    # `max(..., 1)` guard avoids a ZeroDivisionError when the graph contains
    # no artifact nodes.
    x_lim_left = max(-2 - 1.5 / max(len(a_nodes), 1),
                     min(a_node_x_min, e_node_x_min) - 1.0)
    x_lim_right = max(a_node_x, e_node_x) + 0.1
    plt.xlim(x_lim_left, x_lim_right)

    plt.show()