AI-MXNet
view release on metacpan or search on metacpan
t/test_rnn.t view on Meta::CPAN
use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);
use PDL;
use Test::More tests => 54;
# Unroll a plain RNNCell (prefix 'rnn_') for three time steps and check:
#   - the sorted parameter names registered on the cell,
#   - the names of the three per-step outputs of the grouped symbol,
#   - the output shapes inferred from (10, 50) per-step inputs.
# Emits exactly three Test::More assertions; the file declares a fixed
# plan, so that count must not change.
sub test_rnn
{
    my $cell = mx->rnn->RNNCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'RNNCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'RNNCell: unrolled output names'
    );
    # Only the second value returned by infer_shape (the output shapes) is
    # checked; the first and third are discarded.
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'RNNCell: inferred output shapes'
    );
}
# Unroll an LSTMCell (prefix 'rnn_', forget_bias => 1) for three time steps
# and check:
#   - the sorted parameter names registered on the cell,
#   - the names of the three per-step outputs of the grouped symbol,
#   - the output shapes inferred from (10, 50) per-step inputs.
# Emits exactly three Test::More assertions; the file declares a fixed
# plan, so that count must not change.
sub test_lstm
{
    my $cell = mx->rnn->LSTMCell(100, prefix=>'rnn_', forget_bias => 1);
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'LSTMCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'LSTMCell: unrolled output names'
    );
    # Only the second value returned by infer_shape (the output shapes) is
    # checked; the first and third are discarded.
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'LSTMCell: inferred output shapes'
    );
}
# Build a two-layer stacked LSTM (prefixes 'l0_' and 'l1_') with a
# non-default forget_bias, bind and initialise it through a Module, then
# check the initialised value of the first argument whose name ends in
# 'i2h_bias'.
# That bias is expected to equal the concatenation
#   [ zeros(H), forget_bias x H, zeros(2H) ]
# where H is the hidden size. Only the second H-sized slice is set;
# presumably that slice is the forget gate, to be confirmed against
# LSTMCell's gate ordering.
# Emits exactly one Test::More assertion; the file declares a fixed plan,
# so that count must not change.
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $num_hidden  = 100;

    my $stack = mx->rnn->SequentialRNNCell();
    for my $layer (0, 1) {
        $stack->add(
            mx->rnn->LSTMCell(
                $num_hidden,
                forget_bias => $forget_bias,
                prefix      => "l${layer}_",
            )
        );
    }

    my $dshape = [32, 1, 200];
    my $data = mx->sym->Variable('data');
    my ($sym) = $stack->unroll(1, inputs => $data, merge_outputs => 1);
    my $mod = mx->mod->Module($sym, context => mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);
    $mod->init_params();

    # The first matching argument is used; the layers are added in order
    # l0_, l1_, so this is expected to be l0_'s bias.
    my ($bias_argument) = grep { /i2h_bias\z/ } @{ $sym->list_arguments };
    die "test_lstm_forget_bias: no argument ending in 'i2h_bias' found\n"
        unless defined $bias_argument;

    my $expected_bias = zeros($num_hidden)->glue(
        0,
        $forget_bias * ones($num_hidden),
        zeros(2 * $num_hidden),
    );

    # get_params returns a list; its first element is the hashref of
    # argument parameters, keyed by argument name.
    my ($arg_params) = $mod->get_params();
    my $abs_diff = ($arg_params->{$bias_argument}->aspdl - $expected_bias)->abs;
    ok(
        ($abs_diff < 1e-07)->all,
        "LSTMCell: $bias_argument initialised with forget_bias=$forget_bias"
    ) or diag("|actual - expected| = $abs_diff");
}
# Unroll a GRUCell (prefix 'rnn_') for three time steps and check:
#   - the sorted parameter names registered on the cell,
#   - the names of the three per-step outputs of the grouped symbol,
#   - the output shapes inferred from (10, 50) per-step inputs.
# Emits exactly three Test::More assertions; the file declares a fixed
# plan, so that count must not change.
sub test_gru
{
    my $cell = mx->rnn->GRUCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'GRUCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'GRUCell: unrolled output names'
    );
    # Only the second value returned by infer_shape (the output shapes) is
    # checked; the first and third are discarded.
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'GRUCell: inferred output shapes'
    );
}
sub test_residual
{
my $cell = mx->rnn->ResidualCell(mx->rnn->GRUCell(50, prefix=>'rnn_'));
( run in 2.594 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )