in core/maxframe/core/graph/core.pyx [0:0]
def to_dot(
    self,
    graph_attrs=None,
    node_attrs=None,
    trunc_key=5,
    result_chunk_keys=None,
    show_columns=False,
):
    """
    Render the graph as Graphviz DOT source text.

    Every operator obtained via ``self.iter_nodes()`` and
    ``self._extract_operators(node)`` is emitted once as a circle node, and
    each of its input/output chunks once as a box node. Edges go
    ``input chunk -> operator -> output chunk``; the layout is bottom-to-top
    (``rankdir=BT``) with curved splines.

    Parameters
    ----------
    graph_attrs : dict, optional
        Extra attributes for a ``graph [...]`` statement; each value is
        rendered with ``self._repr_in_dot``.
    node_attrs : dict, optional
        Extra attributes for a ``node [...]`` statement; each value is
        rendered with ``self._repr_in_dot``.
    trunc_key : int
        Number of leading characters of an operator key kept in its label.
        Also forwarded to ``self._gen_chunk_key`` when labelling chunks.
    result_chunk_keys : container, optional
        Keys of output chunks to highlight with a filled box style.
    show_columns : bool
        If True, label each ``operator -> output chunk`` edge with the
        chunk's column labels (``dtypes.index``) or, when those are not
        available, with its ``name`` (``"N/A"`` if it has neither).

    Returns
    -------
    str
        The DOT source, ending with the closing ``}`` (no trailing newline).
    """
    sio = StringIO()
    sio.write('digraph {\n')
    sio.write('splines=curved\n')
    sio.write('rankdir=BT\n')

    def attr_list(attrs):
        # Space-separated ``key=value`` pairs for a graph/node attribute statement.
        return ' '.join(f'{k}={self._repr_in_dot(v)}' for k, v in attrs.items())

    if graph_attrs:
        sio.write(f'graph [{attr_list(graph_attrs)}];\n')
    if node_attrs:
        sio.write(f'node [{attr_list(node_attrs)}];\n')

    chunk_style = '[shape=box]'
    result_chunk_style = '[shape=box,style=filled,fillcolor=cadetblue1]'
    operator_style = '[shape=circle]'

    # Holds both operator keys and chunk keys: a node is declared only the
    # first time it is seen, and an already-visited operator is skipped.
    visited = set()

    def quote(text):
        # Inside a DOT quoted string a bare '"' ends the string, so escape it
        # to keep user-provided column/series names from breaking the output.
        return '"' + str(text).replace('"', '\\"') + '"'

    def get_col_names(obj):
        dtypes = getattr(obj, "dtypes", None)
        if dtypes is not None:
            # Column labels are not necessarily strings (e.g. integer columns),
            # so stringify each one before joining.
            return quote(','.join(str(col) for col in dtypes.index))
        if hasattr(obj, "name"):
            return quote(obj.name)
        return "\"N/A\""

    def chunk_label(chunk):
        return f'"Chunk:{self._gen_chunk_key(chunk, trunc_key)}"'

    for node in self.iter_nodes():
        for op in self._extract_operators(node):
            if op.key in visited:
                continue
            op_name = type(op).__name__
            if op.stage is not None:
                op_name = f'{op_name}:{op.stage.name}'
            # id(op) keeps distinct operator objects apart even when their
            # truncated keys coincide.
            op_label = f'"{op_name}:{op.key[:trunc_key]}_{id(op)}"'

            for input_chunk in (op.inputs or []):
                in_label = chunk_label(input_chunk)
                if input_chunk.key not in visited:
                    sio.write(f'{in_label} {chunk_style}\n')
                    visited.add(input_chunk.key)
                if op.key not in visited:
                    sio.write(f'{op_label} {operator_style}\n')
                    visited.add(op.key)
                sio.write(f'{in_label} -> {op_label}\n')

            for output_chunk in (op.outputs or []):
                out_label = chunk_label(output_chunk)
                if output_chunk.key not in visited:
                    style = chunk_style
                    if result_chunk_keys and output_chunk.key in result_chunk_keys:
                        style = result_chunk_style
                    sio.write(f'{out_label} {style}\n')
                    visited.add(output_chunk.key)
                if op.key not in visited:
                    sio.write(f'{op_label} {operator_style}\n')
                    visited.add(op.key)
                sio.write(f'{op_label} -> {out_label}')
                if show_columns:
                    sio.write(f' [ label={get_col_names(output_chunk)} ]')
                sio.write("\n")

    sio.write('}')
    return sio.getvalue()