AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Visualization.pm view on Meta::CPAN
{
for my $i (1..@{ $pre_node }-1)
{
$fields = ['', '', '', $pre_node->[$i]];
$print_row->($fields, $positions);
}
}
return $cur_param;
};
my $total_params = 0;
enumerate(sub {
my ($i, $node) = @_;
my $out_shape = [];
my $op = $node->{op};
return if($op eq 'null' and $i > 0);
if($op ne 'null' or exists $heads{$i})
{
if($show_shape)
{
my $key = $node->{name};
$key .= '_output' if $op ne 'null';
if(exists $shape_dict{ $key })
{
my $end = @{ $shape_dict{ $key } };
@{ $out_shape } = @{ $shape_dict{ $key } }[1..$end-1];
}
}
}
$total_params += $print_layer_summary->($nodes->[$i], $out_shape);
if($i == @{ $nodes } - 1)
{
print('=' x $line_length, "\n");
}
else
{
print('_' x $line_length, "\n");
}
}, $nodes);
print("Total params: $total_params\n");
print('_' x $line_length, "\n");
}
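=head2 plot_network

    Convert a symbol to a dot object for visualization.
    Requires the GraphViz module (confesses if it cannot be loaded).

    Parameters
    ----------
    title: Str, default 'plot'
        title of the dot graph
    symbol: AI::MXNet::Symbol
        symbol to be visualized
    save_format: Str, default 'ps'
        output format passed on to the dot object
    shape: Maybe[HashRef[Shape]]
        If supplied, the visualization will include the shape
        of each tensor on the edges between nodes.
        Confesses with "Input shape is incomplete" if the
        shapes cannot be inferred from the given input shapes.
    node_attrs: HashRef of node attributes
        Merged over the default node attributes
        (shape => "box", fixedsize => "true", width => 1.3,
        height => 0.8034, style => "filled").
        For example:
        {shape => "oval", fixedsize => "false"}
        draws the network nodes as ovals.
    hide_weights: Bool, default true
        If true, input nodes whose names end in `_weight`,
        `_bias`, `_beta`, `_gamma`, `_moving_var` or
        `_moving_mean` will be hidden.

    Returns
    -------
    dot: Digraph
        dot object of the symbol

=cut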
method plot_network(
AI::MXNet::Symbol $symbol,
Str :$title='plot',
Str :$save_format='ps',
Maybe[HashRef[Shape]] :$shape=,
HashRef[Str] :$node_attrs={},
Bool :$hide_weights=1
)
{
eval { require GraphViz; };
Carp::confess("plot_network requires GraphViz module") if $@;
my $draw_shape;
my %shape_dict;
if(defined $shape)
{
$draw_shape = 1;
my $interals = $symbol->get_internals;
my (undef, $out_shapes, undef) = $interals->infer_shape(%{ $shape });
Carp::confess("Input shape is incomplete")
unless defined $out_shapes;
@shape_dict{ @{ $interals->list_outputs } } = @{ $out_shapes };
}
my $conf = decode_json($symbol->tojson);
my $nodes = $conf->{nodes};
my %node_attr = (
qw/ shape box fixedsize true
width 1.3 height 0.8034 style filled/,
%{ $node_attrs }
);
my $dot = AI::MXNet::Visualization::PythonGraphviz->new(
graph => GraphViz->new(name => $title),
format => $save_format
);
# color map
my @cm = (
"#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
"#fdb462", "#b3de69", "#fccde5"
);
# make nodes
my %hidden_nodes;
for my $node (@{ $nodes })
{
my $op = $node->{op};
my $name = $node->{name};
# input data
my %attr = %node_attr;
my $label = $name;
if($op eq 'null')
{
if($name =~ /(?:_weight|_bias|_beta|_gamma|_moving_var|_moving_mean)$/)
{
( run in 0.557 second using v1.01-cache-2.11-cpan-39bf76dae61 )