in tensorflow_lattice/python/visualization.py [0:0]
def draw_model_graph(model_graph,
                     calibrator_dpi=30,
                     calibrator_figsize=None,
                     image_format='png'):
  """Draws the model graph.

  This function requires IPython and graphviz packages.

  ```
  model_graph = estimators.get_model_graph(saved_model_path)
  visualization.draw_model_graph(model_graph)
  ```

  Calibrator nodes are rendered as boxes containing a plot of the calibrator.
  Each plot is saved as an image file in the system temporary directory, and
  the rendered graph is written there as well before being displayed.

  Args:
    model_graph: A `model_info.ModelInfo` object to plot.
    calibrator_dpi: The DPI for calibrator plots inside the graph nodes.
    calibrator_figsize: The figsize parameter for calibrator plots.
    image_format: Format of the image to produce. Using 'svg' format can help
      with font rendering issues.

  Raises:
    graphviz.backend.ExecutableNotFound: If rendering fails because the `dot`
      binary cannot be found and 'IPython.core.magics.namespace' is not a
      loaded module. When that module is loaded, a message is printed
      instead of raising.
  """
  import graphviz  # pylint: disable=g-import-not-at-top

  calibration_node_types = (model_info.PWLCalibrationNode,
                            model_info.CategoricalCalibrationNode)

  # `tempfile.tempdir` is None until some tempfile function has initialized
  # it, which would make `os.path.join` fail. `gettempdir()` performs that
  # initialization and returns the already configured value otherwise.
  temp_dir = tempfile.gettempdir()

  dot = graphviz.Digraph(format=image_format, engine='dot')
  dot.graph_attr['ranksep'] = '0.75'

  # Check if we need split nodes for shared calibration, i.e. whether any
  # calibrator feeds more than one downstream node.
  model_has_shared_calibration = any(
      isinstance(node, calibration_node_types) and
      len(_output_nodes(model_graph, node)) > 1
      for node in model_graph.nodes)

  # Maps calibrator node id -> (split node id, split node label).
  split_nodes = {}
  for node in model_graph.nodes:
    node_id = _node_id(node)
    if isinstance(node, calibration_node_types):
      # Add node for calibrator with calibrator plot inside.
      fig = plot_calibrator_nodes([node], figsize=calibrator_figsize)
      image_filename = os.path.join(
          temp_dir, 'i{}.{}'.format(node_id, image_format))
      plt.savefig(image_filename, dpi=calibrator_dpi)
      plt.close(fig)
      dot.node(
          node_id, '', image=image_filename, imagescale='true', shape='box')

      # Add input feature node.
      node_is_feature_calibration = isinstance(node.input_node,
                                               model_info.InputFeatureNode)
      if node_is_feature_calibration:
        input_node_id = node_id + 'input'
        dot.node(input_node_id, node.input_node.name)
        dot.edge(input_node_id + ':s', node_id + ':n')

        # Add split node for shared calibration.
        if model_has_shared_calibration:
          split_node_id = node_id + 'calibrated'
          split_node_name = 'calibrated {}'.format(node.input_node.name)
          dot.node(split_node_id, split_node_name)
          dot.edge(node_id + ':s', split_node_id + ':n')
          split_nodes[node_id] = (split_node_id, split_node_name)
    elif not isinstance(node, model_info.InputFeatureNode):
      dot.node(node_id, _node_name(node), shape='box', margin='0.3')

    if node is model_graph.output_node:
      output_node_id = node_id + 'output'
      dot.node(output_node_id, 'output')
      dot.edge(node_id + ':s', output_node_id)

  # Connect every node to its inputs. Input feature nodes are skipped here
  # because their edges were already added next to their calibrators above.
  for node in model_graph.nodes:
    node_id = _node_id(node)
    for input_node in _input_nodes(node):
      if isinstance(input_node, model_info.InputFeatureNode):
        continue
      input_node_id = _node_id(input_node)
      if input_node_id in split_nodes:
        # Give each consumer of a shared calibrator its own copy of the split
        # node, identified by the split node id plus the consumer id.
        split_node_id, split_node_name = split_nodes[input_node_id]
        input_node_id = split_node_id + node_id
        dot.node(input_node_id, split_node_name)
      dot.edge(input_node_id + ':s', node_id)

  dot_filename = os.path.join(temp_dir, 'dot')
  try:
    image_path = dot.render(dot_filename)
    _display(image_path=image_path, image_format=image_format)
  except graphviz.backend.ExecutableNotFound as e:
    if 'IPython.core.magics.namespace' in sys.modules:
      # Similar to Keras visualization lib, we don't raise an exception here to
      # avoid crashing notebooks during tests.
      print(
          'dot binaries were not found or not in PATH. The system running the '
          'colab binary might not have graphviz package installed: format({})'
          .format(e))
    else:
      # Bare raise keeps the original traceback intact.
      raise