AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN.pm  view on Meta::CPAN

package AI::MXNet::RNN;
use strict;
use warnings;
use AI::MXNet::Function::Parameters;
use AI::MXNet::RNN::IO;
use AI::MXNet::RNN::Cell;
use List::Util qw(max);

=encoding UTF-8

=head1 NAME

    AI::MXNet::RNN - Functions for constructing recurrent neural networks.

=cut

=head1 SYNOPSIS


=head1 DESCRIPTION

    Functions for constructing recurrent neural networks.

=cut

=head2 save_rnn_checkpoint

    Save a checkpoint for a model that uses RNN cells.
    The weights are unpacked (via each cell's unpack_weights) before saving.

    Parameters
    ----------
    cells : AI::MXNet::RNN::Cell::Base or array ref of AI::MXNet::RNN::Cell::Base
        The RNN cell(s) used by this symbol.
    prefix : str
        Prefix of model name.
    epoch : int
        The epoch number of the model.
    symbol : AI::MXNet::Symbol
        The input symbol (the network to be saved).
    arg_params : hash ref of str to AI::MXNet::NDArray
        Model parameter, hash ref of name to NDArray of net's weights.
    aux_params : hash ref of str to AI::MXNet::NDArray
        Model parameter, hash ref of name to NDArray of net's auxiliary states.

    Notes
    -----
    - prefix-symbol.json will be saved for symbol.
    - prefix-epoch.params will be saved for parameters.
=cut

method save_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base] $cells,
    Str                                                             $prefix,
    Int                                                             $epoch,
    AI::MXNet::Symbol                                               $symbol,
    HashRef[AI::MXNet::NDArray]                                     $arg_params,
    HashRef[AI::MXNet::NDArray]                                     $aux_params
)
{
    # Normalize to a list of cells. A single (blessed) cell is not a plain
    # ARRAY ref -- ref() yields its class name -- so it is wrapped.
    my @cell_list = ref($cells) eq 'ARRAY' ? @{ $cells } : ($cells);

    # Copy into a lexical hash so the reassignments below never touch the
    # caller's hash. Each cell receives the result of the previous cell's
    # unpack_weights, so the unpacking is chained across all cells.
    my %unpacked;
    %unpacked = %{ $arg_params };
    foreach my $cell (@cell_list)
    {
        my $result = $cell->unpack_weights(\%unpacked);
        %unpacked = %{ $result };
    }

    # Persist the symbol plus the unpacked arg params and untouched aux params.
    AI::MXNet::Module->model_save_checkpoint(
        $prefix, $epoch, $symbol, \%unpacked, $aux_params
    );
}


=head2 load_rnn_checkpoint

    Load model checkpoint from file.
    Pack weights after loading.



( run in 0.672 second using v1.01-cache-2.11-cpan-39bf76dae61 )