AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN.pm  view on Meta::CPAN

        The symbol configuration of computation network.
    arg_params : hash ref of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : hash ref of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.

    Notes
    -----
    - symbol will be loaded from prefix-symbol.json.
    - parameters will be loaded from prefix-epoch.params.
=cut

method load_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base] $cells,
    Str                                                             $prefix,
    Int                                                             $epoch
)
{
    # Fetch the symbol and both parameter sets for the given prefix/epoch.
    my ($symbol, $arg_params, $aux_params) = AI::MXNet::Module->load_checkpoint($prefix, $epoch);

    # A single cell and an array ref of cells are both accepted.
    my @cell_list = ref($cells) eq 'ARRAY' ? @{ $cells } : ($cells);

    # The result of each cell's pack_weights call is fed to the next cell.
    for my $cell (@cell_list)
    {
        $arg_params = $cell->pack_weights($arg_params);
    }
    return ($symbol, $arg_params, $aux_params);
}

=head2 do_rnn_checkpoint

    Makes a callback that checkpoints the Module to prefix every period epochs.
    Unpacks the weights used by the cells before saving.

    Parameters
    ----------
    cells : AI::MXNet::RNN::Cell::Base or array ref of AI::MXNet::RNN::Cell::Base
        RNN cell(s) used by this module.
    prefix : str
        The file prefix to checkpoint to
    period : int
        How many epochs to wait before checkpointing. Default is 1.

    Returns
    -------
    callback : function
        The callback function that can be passed as iter_end_callback to fit.
=cut

method do_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base]  $cells,
    Str                                                              $prefix,
    Int                                                              $period=1
)
{
    # $period is optional and defaults to 1, as the POD for this method states.
    # Clamp it so that a zero or negative value never reaches the modulus below.
    $period = max(1, $period);

    # The returned closure captures $cells, $prefix and $period.
    return sub {
        my ($iter_no, $sym, $arg, $aux) = @_;
        # Save only when ($iter_no + 1) is a multiple of $period; the checkpoint
        # is written under epoch number $iter_no + 1.
        if (($iter_no + 1) % $period == 0)
        {
            __PACKAGE__->save_rnn_checkpoint($cells, $prefix, $iter_no+1, $sym, $arg, $aux);
        }
    };
}

## Shortcut constructors, provided to closely resemble the Python API usage
method RNNCell(@args)
{
    # An odd argument count means the leading value is a bare num_hidden.
    unshift(@args, 'num_hidden') if @args % 2;
    return AI::MXNet::RNN::Cell->new(@args);
}
method LSTMCell(@args)
{
    # An odd argument count means the leading value is a bare num_hidden.
    unshift(@args, 'num_hidden') if @args % 2;
    return AI::MXNet::RNN::LSTMCell->new(@args);
}
method GRUCell(@args)
{
    # An odd argument count means the leading value is a bare num_hidden.
    unshift(@args, 'num_hidden') if @args % 2;
    return AI::MXNet::RNN::GRUCell->new(@args);
}
method FusedRNNCell(@args)
{
    # An odd argument count means the leading value is a bare num_hidden.
    unshift(@args, 'num_hidden') if @args % 2;
    return AI::MXNet::RNN::FusedCell->new(@args);
}
method SequentialRNNCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::SequentialCell->new(@args);
}
method BidirectionalCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::BidirectionalCell->new(@args);
}
method DropoutCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::DropoutCell->new(@args);
}
method ZoneoutCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::ZoneoutCell->new(@args);
}
method ConvRNNCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::ConvCell->new(@args);
}
method ConvLSTMCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::ConvLSTMCell->new(@args);
}
method ConvGRUCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::ConvGRUCell->new(@args);
}
method ResidualCell(@args)
{
    # Arguments are forwarded to the constructor untouched.
    return AI::MXNet::RNN::ResidualCell->new(@args);
}
method encode_sentences(@args)
{
    # Shortcut that delegates to AI::MXNet::RNN::IO->encode_sentences.
    return AI::MXNet::RNN::IO->encode_sentences(@args);
}
method BucketSentenceIter(@args)
{
    # The first two positional values are the sentences and the batch size;
    # everything after them is handed to the constructor as given.
    my ($sentences, $batch_size, @rest) = @args;
    return AI::MXNet::BucketSentenceIter->new(
        sentences  => $sentences,
        batch_size => $batch_size,
        @rest
    );
}

1;



( run in 0.887 second using v1.01-cache-2.11-cpan-39bf76dae61 )