AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN.pm view on Meta::CPAN
The symbol configuration of the computation network.
arg_params : hash ref of str to NDArray
Model parameters: a mapping from name to NDArray holding the network's weights.
aux_params : hash ref of str to NDArray
Model parameters: a mapping from name to NDArray holding the network's auxiliary states.
Notes
-----
- symbol will be loaded from prefix-symbol.json.
- parameters will be loaded from prefix-epoch.params.
=cut
method load_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base] $cells,
    Str $prefix,
    Int $epoch
)
{
    # Load the raw (symbol, arg params, aux params) triple via the Module API.
    my ($symbol, $arg_params, $aux_params) = AI::MXNet::Module->load_checkpoint($prefix, $epoch);

    # Accept either a single cell or an array ref of cells.
    my @cell_list = ref($cells) eq 'ARRAY' ? @{ $cells } : ($cells);

    # Each cell, in order, repacks its weights within the argument hash;
    # the output of one cell feeds the next.
    for my $cell (@cell_list)
    {
        $arg_params = $cell->pack_weights($arg_params);
    }

    return ($symbol, $arg_params, $aux_params);
}
=head2 do_rnn_checkpoint
Make a callback that checkpoints the Module to prefix every `period` epochs.
Unpacks the weights used by the cells before saving.
Parameters
----------
cells : AI::MXNet::RNN::Cell::Base or array ref of AI::MXNet::RNN::Cell::Base
RNN cells used by this module.
prefix : str
The file prefix to checkpoint to.
period : int
How many epochs to wait between checkpoints; values below 1 are treated as 1. Default is 1.
Returns
-------
callback : function
The callback function that can be passed as iter_end_callback to fit.
=cut
method do_rnn_checkpoint(
    AI::MXNet::RNN::Cell::Base|ArrayRef[AI::MXNet::RNN::Cell::Base] $cells,
    Str $prefix,
    Int $period=1
)
{
    # $period defaults to 1 as documented above. Clamp it to at least 1:
    # a zero period would make the modulus below die ("Illegal modulus zero").
    $period = max(1, $period);

    # The returned closure captures $cells, $prefix and the clamped $period.
    return sub {
        my ($iter_no, $sym, $arg, $aux) = @_;
        # The checkpoint is saved under epoch number $iter_no + 1
        # (presumably $iter_no is zero-based — confirm against the caller).
        if (($iter_no + 1) % $period == 0)
        {
            __PACKAGE__->save_rnn_checkpoint($cells, $prefix, $iter_no + 1, $sym, $arg, $aux);
        }
    };
}
## Shortcut constructors, so that usage closely resembles the Python API.
## The cells in this group accept num_hidden positionally: when an odd
## number of arguments is given, the first one is taken as num_hidden.
method RNNCell(@args)
{
    unshift @args, 'num_hidden' if @args % 2;
    return AI::MXNet::RNN::Cell->new(@args);
}
method LSTMCell(@args)
{
    unshift @args, 'num_hidden' if @args % 2;
    return AI::MXNet::RNN::LSTMCell->new(@args);
}
method GRUCell(@args)
{
    unshift @args, 'num_hidden' if @args % 2;
    return AI::MXNet::RNN::GRUCell->new(@args);
}
method FusedRNNCell(@args)
{
    unshift @args, 'num_hidden' if @args % 2;
    return AI::MXNet::RNN::FusedCell->new(@args);
}
## The remaining cell shortcuts forward their arguments to new() unchanged.
method SequentialRNNCell(@args)
{
    return AI::MXNet::RNN::SequentialCell->new(@args);
}
method BidirectionalCell(@args)
{
    return AI::MXNet::RNN::BidirectionalCell->new(@args);
}
method DropoutCell(@args)
{
    return AI::MXNet::RNN::DropoutCell->new(@args);
}
method ZoneoutCell(@args)
{
    return AI::MXNet::RNN::ZoneoutCell->new(@args);
}
method ConvRNNCell(@args)
{
    return AI::MXNet::RNN::ConvCell->new(@args);
}
method ConvLSTMCell(@args)
{
    return AI::MXNet::RNN::ConvLSTMCell->new(@args);
}
method ConvGRUCell(@args)
{
    return AI::MXNet::RNN::ConvGRUCell->new(@args);
}
method ResidualCell(@args)
{
    return AI::MXNet::RNN::ResidualCell->new(@args);
}
## Delegates to AI::MXNet::RNN::IO->encode_sentences with identical arguments.
method encode_sentences(@args)
{
    return AI::MXNet::RNN::IO->encode_sentences(@args);
}
## Maps positional (sentences, batch_size, key => value ...) arguments onto
## AI::MXNet::BucketSentenceIter's named constructor arguments.
method BucketSentenceIter(@args)
{
    my ($sentences, $batch_size, @rest) = @args;
    return AI::MXNet::BucketSentenceIter->new(
        sentences  => $sentences,
        batch_size => $batch_size,
        @rest
    );
}
1;
( run in 0.887 second using v1.01-cache-2.11-cpan-39bf76dae61 )