AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Visualization.pm  view on Meta::CPAN

        {
            for my $i (1..@{ $pre_node }-1)
            {
                $fields = ['', '', '', $pre_node->[$i]];
                $print_row->($fields, $positions);
            }
        }
        return $cur_param;
    };
    my $total_params = 0;
    enumerate(sub {
        my ($i, $node) = @_;
        my $out_shape = [];
        my $op = $node->{op};
        return if($op eq 'null' and $i > 0);
        if($op ne 'null' or exists $heads{$i})
        {
            if($show_shape)
            {
                my $key = $node->{name};
                $key .= '_output' if $op ne 'null';
                if(exists $shape_dict{ $key })
                {
                    my $end = @{ $shape_dict{ $key } };
                    @{ $out_shape } = @{ $shape_dict{ $key } }[1..$end-1];
                }
            }
        }
        $total_params += $print_layer_summary->($nodes->[$i], $out_shape);
        if($i == @{ $nodes } - 1)
        {
            print('=' x $line_length, "\n");
        }
        else
        {
            print('_' x $line_length, "\n");
        }
    }, $nodes);
    print("Total params: $total_params\n");
    print('_' x $line_length, "\n");
}

=head2 plot_network

    Converts a symbol to a dot object for visualization.

    Parameters
    ----------
    title: str
        title of the dot graph (default 'plot')
    save_format: str
        output format passed to the graph wrapper (default 'ps')
    symbol: AI::MXNet::Symbol
        symbol to be visualized
    shape: HashRef[Shape]
        If supplied, the visualization will include the shape
        of each tensor on the edges between nodes.
    node_attrs: HashRef[Str]
        node attributes that override the built-in defaults,
        for example:
            {shape => "oval", fixedsize => "false"}
            plots the network nodes as ovals
    hide_weights: Bool
        if True (default) then inputs with names like `*_weight`
        or `*_bias` will be hidden

    Returns
    -------
    dot: Digraph
        dot object of symbol

=cut


method plot_network(
    AI::MXNet::Symbol       $symbol,
    Str                    :$title='plot',
    Str                    :$save_format='ps',
    Maybe[HashRef[Shape]]  :$shape=,
    HashRef[Str]           :$node_attrs={},
    Bool                   :$hide_weights=1
)
{
    eval { require GraphViz; };
    Carp::confess("plot_network requires GraphViz module") if $@;
    my $draw_shape;
    my %shape_dict;
    if(defined $shape)
    {
        $draw_shape = 1;
        my $interals = $symbol->get_internals;
        my (undef, $out_shapes, undef) = $interals->infer_shape(%{ $shape });
        Carp::confess("Input shape is incomplete")
            unless defined $out_shapes;
        @shape_dict{ @{ $interals->list_outputs } } = @{ $out_shapes };
    }
    my $conf = decode_json($symbol->tojson);
    my $nodes = $conf->{nodes};
    my %node_attr = (
        qw/ shape box fixedsize true
            width 1.3 height 0.8034 style filled/,
        %{ $node_attrs }
    );
    my $dot = AI::MXNet::Visualization::PythonGraphviz->new(
        graph  => GraphViz->new(name => $title),
        format => $save_format
    );
    # color map
    my @cm = (
        "#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
        "#fdb462", "#b3de69", "#fccde5"
    );
    # make nodes
    my %hidden_nodes;
    for my $node (@{ $nodes })
    {
        my $op   = $node->{op};
        my $name = $node->{name};
        # input data
        my %attr = %node_attr;
        my $label = $name;
        if($op eq 'null')
        {
            if($name =~ /(?:_weight|_bias|_beta|_gamma|_moving_var|_moving_mean)$/)
            {



( run in 0.557 second using v1.01-cache-2.11-cpan-39bf76dae61 )