perl-package/AI-MXNet/lib/AI/MXNet/Visualization.pm (309 lines of code) (raw):
package AI::MXNet::Visualization;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use JSON::PP;
=encoding UTF-8
=head1 NAME
AI::MXNet::Visualization - Visualization support for the Perl interface to the MXNet machine learning library
=head1 SYNOPSIS
use strict;
use warnings;
use AI::MXNet qw(mx);
### model
my $data = mx->symbol->Variable('data');
my $conv1= mx->symbol->Convolution(data => $data, name => 'conv1', num_filter => 32, kernel => [3,3], stride => [2,2]);
my $bn1 = mx->symbol->BatchNorm(data => $conv1, name => "bn1");
my $act1 = mx->symbol->Activation(data => $bn1, name => 'relu1', act_type => "relu");
my $mp1 = mx->symbol->Pooling(data => $act1, name => 'mp1', kernel => [2,2], stride =>[2,2], pool_type=>'max');
my $conv2= mx->symbol->Convolution(data => $mp1, name => 'conv2', num_filter => 32, kernel=>[3,3], stride=>[2,2]);
my $bn2 = mx->symbol->BatchNorm(data => $conv2, name=>"bn2");
my $act2 = mx->symbol->Activation(data => $bn2, name=>'relu2', act_type=>"relu");
my $mp2 = mx->symbol->Pooling(data => $act2, name => 'mp2', kernel=>[2,2], stride=>[2,2], pool_type=>'max');
my $fl = mx->symbol->Flatten(data => $mp2, name=>"flatten");
my $fc1 = mx->symbol->FullyConnected(data => $fl, name=>"fc1", num_hidden=>30);
my $act3 = mx->symbol->Activation(data => $fc1, name=>'relu3', act_type=>"relu");
my $fc2 = mx->symbol->FullyConnected(data => $act3, name=>'fc2', num_hidden=>10);
my $softmax = mx->symbol->SoftmaxOutput(data => $fc2, name => 'softmax');
## creates the image file network.png in the working directory
mx->viz->plot_network($softmax, save_format => 'png')->render("network.png");
=head1 DESCRIPTION
Visualization support for the Perl interface to the MXNet machine learning library
=head1 Class methods
=head2 print_summary
Prints a layer-by-layer summary of the symbol: layer name and type, output shape, parameter count and previous layers.
Parameters
----------
symbol: AI::MXNet::Symbol
symbol to be visualized
shape: hashref
hashref of shapes, str->shape (arrayref[int]), given input shapes
line_length: int
total length of printed lines
positions: arrayref[float]
relative or absolute positions of log elements in each line
Returns
------
nothing
=cut
# Prints a layer-by-layer summary table of $symbol: one row per operator node
# showing "name(op)", the output shape (only when $shape is supplied), the
# parameter count and the layers feeding it, followed by the total parameter
# count. Parameter counts are computed for Convolution, FullyConnected and
# BatchNorm nodes only; every other operator contributes 0.
#
# Parameters:
#   $symbol      - AI::MXNet::Symbol to summarise.
#   $shape       - optional hashref of input name => shape; when given, the
#                  shapes of all internal outputs are inferred and used for
#                  the 'Output Shape' and 'Param #' columns.
#   $line_length - total width of the printed table (default 120).
#   $positions   - end position of each column; when the last entry is <= 1
#                  the values are treated as fractions of $line_length.
#
# The table is written with print; the return value is not meaningful.
# Dies (Carp::confess) when $shape is given but shape inference returns no
# output shapes.
method print_summary(
    AI::MXNet::Symbol $symbol,
    Maybe[HashRef[Shape]] $shape=,
    Int $line_length=120,
    ArrayRef[Num] $positions=[.44, .64, .74, 1]
)
{
    my $show_shape;
    my %shape_dict;
    if(defined $shape)
    {
        $show_shape = 1;
        my $internals = $symbol->get_internals;
        my (undef, $out_shapes, undef) = $internals->infer_shape(%{ $shape });
        Carp::confess("Input shape is incomplete")
            unless defined $out_shapes;
        @shape_dict{ @{ $internals->list_outputs } } = @{ $out_shapes };
    }
    my $conf  = decode_json($symbol->tojson);
    my $nodes = $conf->{nodes};
    # NOTE(review): every element of the first head entry becomes a key here,
    # not just its node id -- confirm this is the intended head set.
    my %heads = map { $_ => 1 } @{ $conf->{heads}[0] };
    if($positions->[-1] <= 1)
    {
        # relative positions: scale them to absolute columns
        $positions = [map { int($line_length * $_) } @{ $positions }];
    }
    # The operator attributes are read from 'attr' when present and from
    # 'attrs' otherwise, so both spellings of the JSON key are accepted.
    my $attrs_of = sub {
        my ($node) = @_;
        return $node->{attr} // $node->{attrs} // {};
    };
    # header names for the different log elements
    my $to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Previous Layer'];
    # Prints one table row: each field is appended, then the line is cut or
    # space-padded so that it ends exactly at the field's column position.
    my $print_row = sub {
        my ($fields, $columns) = @_;
        my $line = '';
        for my $i (0 .. $#{ $fields })
        {
            $line .= $fields->[$i] // '';
            $line  = substr($line, 0, $columns->[$i]);
            $line .= ' ' x ($columns->[$i] - length($line));
        }
        print $line, "\n";
    };
    print('_' x $line_length, "\n");
    $print_row->($to_display, $positions);
    print('=' x $line_length, "\n");
    # Prints the row(s) for a single node and returns its parameter count.
    my $print_layer_summary = sub {
        my ($node, $out_shape) = @_;
        my $op         = $node->{op};
        my $attrs      = $attrs_of->($node);
        my @pre_nodes;
        my $pre_filter = 0;
        if($op ne 'null')
        {
            for my $item (@{ $node->{inputs} })
            {
                my $input_node = $nodes->[$item->[0]];
                my $input_name = $input_node->{name};
                next unless $input_node->{op} ne 'null' or exists $heads{ $item->[0] };
                push @pre_nodes, $input_name;
                next unless $show_shape;
                my $key = $input_name;
                $key .= '_output' if $input_node->{op} ne 'null';
                if(exists $shape_dict{ $key })
                {
                    # second shape dimension of each input is accumulated
                    $pre_filter += int($shape_dict{ $key }[1] // 0);
                }
            }
        }
        my $cur_param = 0;
        if($op eq 'Convolution')
        {
            my $num_filter = $attrs->{num_filter} // 0;
            $cur_param = $pre_filter * $num_filter;
            # multiply by every kernel dimension found in the kernel string
            for my $dim (($attrs->{kernel} // '') =~ /(\d+)/g)
            {
                $cur_param *= $dim;
            }
            $cur_param += $num_filter;
        }
        elsif($op eq 'FullyConnected')
        {
            $cur_param = $pre_filter * (($attrs->{num_hidden} // 0) + 1);
        }
        elsif($op eq 'BatchNorm')
        {
            if($show_shape)
            {
                my $key = "$node->{name}_output";
                # a missing shape entry counts as 0 instead of warning
                my $num_filter = exists $shape_dict{ $key }
                    ? ($shape_dict{ $key }[1] // 0)
                    : 0;
                $cur_param = $num_filter * 2;
            }
        }
        # An array reference is always true, so emptiness has to be tested
        # on the array itself.
        my $first_connection = @pre_nodes ? $pre_nodes[0] : '';
        $print_row->(
            [
                $node->{name} . '(' . $op . ')',
                join('x', @{ $out_shape }),
                $cur_param,
                $first_connection
            ],
            $positions
        );
        # remaining predecessors get a row of their own in the last column
        for my $extra (@pre_nodes[1 .. $#pre_nodes])
        {
            $print_row->(['', '', '', $extra], $positions);
        }
        return $cur_param;
    };
    my $total_params = 0;
    for my $i (0 .. $#{ $nodes })
    {
        my $node = $nodes->[$i];
        my $op   = $node->{op};
        # variable nodes are skipped, except for the very first node
        next if $op eq 'null' and $i > 0;
        my $out_shape = [];
        if($show_shape and ($op ne 'null' or exists $heads{ $i }))
        {
            my $key = $node->{name};
            $key .= '_output' if $op ne 'null';
            if(exists $shape_dict{ $key })
            {
                # drop the first dimension of the inferred shape
                my @dims = @{ $shape_dict{ $key } // [] };
                shift @dims;
                $out_shape = \@dims;
            }
        }
        $total_params += $print_layer_summary->($node, $out_shape);
        if($i == $#{ $nodes })
        {
            print('=' x $line_length, "\n");
        }
        else
        {
            print('_' x $line_length, "\n");
        }
    }
    print("Total params: $total_params\n");
    print('_' x $line_length, "\n");
}
=head2 plot_network
Converts a symbol into a GraphViz-backed graph object for visualization.
Parameters
----------
title: str
title of the dot graph
save_format: str
output format used when the returned object is rendered (default 'ps')
symbol: AI::MXNet::Symbol
symbol to be visualized
shape: HashRef[Shape]
If supplied, the visualization will include the shape
of each tensor on the edges between nodes.
node_attrs: HashRef of node's attributes
for example:
{shape => "oval",fixedsize => "false"}
means to plot the network in "oval"
hide_weights: Bool
if True (default) then inputs with names like `*_weight`
or `*_bias` will be hidden
Returns
------
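dot: AI::MXNet::Visualization::PythonGraphviz
object wrapping the GraphViz graph of the symbol; call render on it to produce the output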
=cut
# Builds a GraphViz graph of $symbol and returns it wrapped in an
# AI::MXNet::Visualization::PythonGraphviz object.
#
# Parameters:
#   $symbol       - AI::MXNet::Symbol to draw.
#   $title        - name given to the GraphViz graph (default 'plot').
#   $save_format  - format stored in the returned object (default 'ps').
#   $shape        - optional hashref of input name => shape; when given, the
#                   edges are labelled with the inferred tensor shapes.
#   $node_attrs   - hashref of node attributes overriding the defaults.
#   $hide_weights - when true (default), inputs whose names end in _weight,
#                   _bias, _beta, _gamma, _moving_var or _moving_mean are
#                   left out together with their edges.
#
# Dies (Carp::confess) when the GraphViz module cannot be loaded, or when
# $shape is given but shape inference returns no output shapes.
method plot_network(
    AI::MXNet::Symbol $symbol,
    Str :$title='plot',
    Str :$save_format='ps',
    Maybe[HashRef[Shape]] :$shape=,
    HashRef[Str] :$node_attrs={},
    Bool :$hide_weights=1
)
{
    # rely on the eval's return value; $@ alone is not a reliable signal
    eval { require GraphViz; 1 }
        or Carp::confess("plot_network requires GraphViz module");
    my $draw_shape;
    my %shape_dict;
    if(defined $shape)
    {
        $draw_shape = 1;
        my $internals = $symbol->get_internals;
        my (undef, $out_shapes, undef) = $internals->infer_shape(%{ $shape });
        Carp::confess("Input shape is incomplete")
            unless defined $out_shapes;
        @shape_dict{ @{ $internals->list_outputs } } = @{ $out_shapes };
    }
    my $conf  = decode_json($symbol->tojson);
    my $nodes = $conf->{nodes};
    # default node attributes, overridable by the caller
    my %node_attr = (
        qw/ shape box fixedsize true
            width 1.3 height 0.8034 style filled/,
        %{ $node_attrs }
    );
    my $dot = AI::MXNet::Visualization::PythonGraphviz->new(
        graph  => GraphViz->new(name => $title),
        format => $save_format
    );
    # color map
    my @cm = (
        "#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
        "#fdb462", "#b3de69", "#fccde5"
    );
    # make nodes
    my %hidden_nodes;
    for my $node (@{ $nodes })
    {
        my $op    = $node->{op};
        my $name  = $node->{name};
        # The operator attributes are read from 'attr' when present and from
        # 'attrs' otherwise, so both spellings of the JSON key are accepted.
        my $attrs = $node->{attr} // $node->{attrs} // {};
        my %attr  = %node_attr;
        my $label = $name;
        if($op eq 'null')
        {
            if($name =~ /(?:_weight|_bias|_beta|_gamma|_moving_var|_moving_mean)$/)
            {
                if($hide_weights)
                {
                    $hidden_nodes{$name} = 1;
                }
                # else we don't render a node, but
                # don't add it to the hidden_nodes set
                # so it gets rendered as an empty oval
                next;
            }
            $attr{shape}     = 'ellipse'; # inputs get their own shape
            $attr{fillcolor} = $cm[0];
        }
        elsif($op eq 'Convolution')
        {
            my @k      = ($attrs->{kernel} // '') =~ /(\d+)/g;
            my @stride = ($attrs->{stride} // '') =~ /(\d+)/g;
            $stride[0] //= 1;
            $label = "Convolution\n".join('x',@k).'/'.join('x',@stride).", ".($attrs->{num_filter} // '');
            $attr{fillcolor} = $cm[1];
        }
        elsif($op eq 'FullyConnected')
        {
            $label = "FullyConnected\n".($attrs->{num_hidden} // '');
            $attr{fillcolor} = $cm[1];
        }
        elsif($op eq 'BatchNorm')
        {
            $attr{fillcolor} = $cm[3];
        }
        elsif($op eq 'Activation' or $op eq 'LeakyReLU')
        {
            # act_type may be absent from the attributes
            $label = "$op\n".($attrs->{act_type} // '');
            $attr{fillcolor} = $cm[2];
        }
        elsif($op eq 'Pooling')
        {
            my @k      = ($attrs->{kernel} // '') =~ /(\d+)/g;
            my @stride = ($attrs->{stride} // '') =~ /(\d+)/g;
            $stride[0] //= 1;
            $label = "Pooling\n".($attrs->{pool_type} // '').", ".join('x',@k).'/'.join('x',@stride);
            $attr{fillcolor} = $cm[4];
        }
        elsif($op eq 'Concat' or $op eq 'Flatten' or $op eq 'Reshape')
        {
            $attr{fillcolor} = $cm[5];
        }
        elsif($op eq 'Softmax')
        {
            $attr{fillcolor} = $cm[6];
        }
        else
        {
            $attr{fillcolor} = $cm[7];
            if($op eq 'Custom')
            {
                $label = $attrs->{op_type};
            }
        }
        $dot->graph->add_node($name, label => $label, %attr);
    }
    # add edges
    for my $node (@{ $nodes })
    {
        # variable nodes have no incoming edges to draw
        next if $node->{op} eq 'null';
        my $name = $node->{name};
        for my $item (@{ $node->{inputs} })
        {
            my $input_node = $nodes->[$item->[0]];
            my $input_name = $input_node->{name};
            next if exists $hidden_nodes{ $input_name };
            my %attr = qw/dir back arrowtail normal/;
            # add shapes
            if($draw_shape)
            {
                my $key = $input_name;
                $key .= '_output' if $input_node->{op} ne 'null';
                # Inputs without an inferred shape under this key get an
                # unlabelled edge instead of a fatal undefined dereference.
                if(exists $shape_dict{ $key })
                {
                    # drop the first dimension of the inferred shape
                    my @dims = @{ $shape_dict{ $key } // [] };
                    shift @dims;
                    $attr{label} = join('x', @dims);
                }
            }
            $dot->graph->add_edge($name => $input_name, %attr);
        }
    }
    return $dot;
}
package AI::MXNet::Visualization::PythonGraphviz;
use Mouse;
use AI::MXNet::Types;

# Pairs a GraphViz graph with an output format so the result of
# plot_network can be rendered with a single render call.

# Output format name; render turns it into the as_<format> method to call.
has 'format' => (
    is  => 'ro',
    isa => enum(
        [
            qw/debug canon text ps hpgl pcl mif
               pic gd gd2 gif jpeg png wbmp cmapx
               imap vdx vrml vtx mp fig svg svgz
               plain/
        ]
    )
);

# The wrapped GraphViz object that nodes and edges are added to.
has 'graph' => (
    is  => 'ro',
    isa => 'GraphViz'
);

# Renders the graph by calling the as_<format> method of the wrapped graph.
# $output is optional and handed to that method unchanged (the SYNOPSIS
# passes a file name); the method's result is returned.
method render($output=)
{
    my $graph     = $self->graph;
    my $converter = 'as_' . $self->format;
    return $graph->$converter($output);
}
1;