def draw_model_graph()

in tensorflow_lattice/python/visualization.py [0:0]


def draw_model_graph(model_graph,
                     calibrator_dpi=30,
                     calibrator_figsize=None,
                     image_format='png'):
  """Draws the model graph.

  This function requires IPython and graphviz packages.

  ```
  model_graph = estimators.get_model_graph(saved_model_path)
  visualization.draw_model_graph(model_graph)
  ```

  Each calibration node is drawn as a box containing a plot of its calibrator.
  Each calibrator whose input is a feature gets a node labelled with the
  feature name above it. If any calibrator feeds more than one downstream node
  (shared calibration), each such feature calibrator also gets a
  "calibrated <feature>" node below it. Every consumer of that calibrator is
  then fed by its own copy of that node. Calibrator plot images and the
  rendered graph are written to the system temporary directory.

  Args:
    model_graph: a `model_info.ModelInfo` object to plot.
    calibrator_dpi: The DPI for calibrator plots inside the graph nodes.
    calibrator_figsize: The figsize parameter for calibrator plots.
    image_format: Format of the image to produce. Using 'svg' format can help
      with font rendering issues.

  Raises:
    graphviz.backend.ExecutableNotFound: if the graphviz `dot` binary cannot be
      run and IPython is not loaded (inside IPython a message is printed
      instead).
  """
  import graphviz  # pylint: disable=g-import-not-at-top

  calibration_node_types = (model_info.PWLCalibrationNode,
                            model_info.CategoricalCalibrationNode)
  # `tempfile.tempdir` stays None until the temp directory has been resolved
  # once; gettempdir() resolves it, so os.path.join never receives None.
  temp_dir = tempfile.gettempdir()

  dot = graphviz.Digraph(format=image_format, engine='dot')
  dot.graph_attr['ranksep'] = '0.75'

  # Split nodes are only needed if some calibrator feeds multiple nodes.
  model_has_shared_calibration = any(
      isinstance(node, calibration_node_types) and
      len(_output_nodes(model_graph, node)) > 1
      for node in model_graph.nodes)

  # Maps a calibrator's node id to (split node id, split node label).
  split_nodes = {}
  for node in model_graph.nodes:
    node_id = _node_id(node)
    if isinstance(node, calibration_node_types):
      # Add node for calibrator with calibrator plot inside.
      fig = plot_calibrator_nodes([node], figsize=calibrator_figsize)
      image_filename = os.path.join(temp_dir,
                                    'i{}.{}'.format(node_id, image_format))
      # Save this specific figure rather than whatever pyplot treats as current.
      fig.savefig(image_filename, dpi=calibrator_dpi)
      plt.close(fig)
      dot.node(
          node_id, '', image=image_filename, imagescale='true', shape='box')

      # Add input feature node.
      if isinstance(node.input_node, model_info.InputFeatureNode):
        input_node_id = node_id + 'input'
        dot.node(input_node_id, node.input_node.name)
        dot.edge(input_node_id + ':s', node_id + ':n')

        # Add split node for shared calibration.
        if model_has_shared_calibration:
          split_node_id = node_id + 'calibrated'
          split_node_name = 'calibrated {}'.format(node.input_node.name)
          dot.node(split_node_id, split_node_name)
          dot.edge(node_id + ':s', split_node_id + ':n')
          split_nodes[node_id] = (split_node_id, split_node_name)

    elif not isinstance(node, model_info.InputFeatureNode):
      dot.node(node_id, _node_name(node), shape='box', margin='0.3')

    if node is model_graph.output_node:
      output_node_id = node_id + 'output'
      dot.node(output_node_id, 'output')
      dot.edge(node_id + ':s', output_node_id)

  for node in model_graph.nodes:
    node_id = _node_id(node)
    for input_node in _input_nodes(node):
      # Input feature nodes are only drawn next to their calibrators (above).
      if isinstance(input_node, model_info.InputFeatureNode):
        continue
      input_node_id = _node_id(input_node)
      if input_node_id in split_nodes:
        # Each consumer gets its own copy of the calibrator's split node.
        split_node_id, split_node_name = split_nodes[input_node_id]
        input_node_id = split_node_id + node_id
        dot.node(input_node_id, split_node_name)

      dot.edge(input_node_id + ':s', node_id)

  graph_filename = os.path.join(temp_dir, 'dot')
  try:
    image_path = dot.render(graph_filename)
  except graphviz.backend.ExecutableNotFound as e:
    if 'IPython.core.magics.namespace' in sys.modules:
      # Similar to Keras visualization lib, we don't raise an exception here to
      # avoid crashing notebooks during tests.
      print(
          'dot binaries were not found or not in PATH. The system running the '
          'colab binary might not have graphviz package installed: format({})'
          .format(e))
    else:
      raise
  else:
    _display(image_path=image_path, image_format=image_format)