AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN.pm view on Meta::CPAN
package AI::MXNet::RNN;
use strict;
use warnings;
use AI::MXNet::Function::Parameters;
use AI::MXNet::RNN::IO;
use AI::MXNet::RNN::Cell;
use List::Util qw(max);
=encoding UTF-8
=head1 NAME
AI::MXNet::RNN - Functions for constructing recurrent neural networks.
=cut
=head1 SYNOPSIS
=head1 DESCRIPTION
Functions for constructing recurrent neural networks.
=cut
=head2 save_rnn_checkpoint
Save checkpoint for model using RNN cells.
Unpacks the cells' weights before saving.
Parameters
----------
cells : AI::MXNet::RNN::Cell or array ref of AI::MXNet::RNN::Cell
The RNN cells used by this symbol.
prefix : str
Prefix of model name.
epoch : int
The epoch number of the model.
symbol : Symbol
The symbol of the network to save.
arg_params : hash ref of str to AI::MXNet::NDArray
Model parameters: hash ref mapping each name to the NDArray holding the net's weights.
aux_params : hash ref of str to AI::MXNet::NDArray
Model parameters: hash ref mapping each name to the NDArray holding the net's auxiliary states.
Notes
-----
- prefix-symbol.json will be saved for symbol.
- prefix-epoch.params will be saved for parameters.
=cut
# Save a checkpoint for a model built from RNN cells.
#
# Every cell's unpack_weights() is applied, in order, to a copy of the
# argument parameters; the result is then handed to
# AI::MXNet::Module->model_save_checkpoint together with the prefix, epoch,
# symbol and auxiliary parameters.
#
# Parameters:
#   $cells      - a single cell or an array ref of cells
#   $prefix     - prefix of the model name
#   $epoch      - epoch number of the model
#   $symbol     - the symbol to save
#   $arg_params - hash ref of name to NDArray (weights)
#   $aux_params - hash ref of name to NDArray (auxiliary states)
#
# Returns whatever model_save_checkpoint returns.
method save_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base] $cells,
    Str $prefix,
    Int $epoch,
    AI::MXNet::Symbol $symbol,
    HashRef[AI::MXNet::NDArray] $arg_params,
    HashRef[AI::MXNet::NDArray] $aux_params
)
{
    # Normalise the input so a lone cell is treated as a one-element list.
    my $cell_list = ref($cells) eq 'ARRAY' ? $cells : [$cells];

    # Work on a shallow copy so the caller's top-level hash is not replaced.
    my %unpacked = %{ $arg_params };

    # Each cell receives the output of the previous one.
    for my $cell (@{ $cell_list })
    {
        my $result = $cell->unpack_weights(\%unpacked);
        %unpacked = %{ $result };
    }

    AI::MXNet::Module->model_save_checkpoint(
        $prefix, $epoch, $symbol, \%unpacked, $aux_params
    );
}
=head2 load_rnn_checkpoint
Load model checkpoint from file.
Pack weights after loading.
( run in 0.672 second using v1.01-cache-2.11-cpan-39bf76dae61 )