AI-MXNet

 view release on metacpan or  search on metacpan

t/test_rnn.t  view on Meta::CPAN

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);
use PDL;
use Test::More tests => 54;

# Unrolls a plain RNNCell (100 hidden units, prefix 'rnn_') over three
# steps and verifies:
#   1. the parameter names registered on the cell,
#   2. the output names of the grouped, unrolled symbol,
#   3. the second value returned by infer_shape() for [10, 50] inputs.
# Runs exactly three assertions, each labelled with the cell type so a
# failure can be told apart from the identical checks in the LSTM/GRU tests.
sub test_rnn
{
    my $cell = mx->rnn->RNNCell(100, prefix => 'rnn_');

    # unroll() is evaluated in list context; only its first return value
    # is kept and then wrapped into a single grouped symbol.
    my ($outputs) = $cell->unroll(3, input_prefix => 'rnn_');
    $outputs = mx->sym->Group($outputs);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)],
        'RNNCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        [qw(rnn_t0_out_output rnn_t1_out_output rnn_t2_out_output)],
        'RNNCell: unrolled output names'
    );

    # Only the second element of infer_shape()'s return list is checked
    # (presumably the output shapes -- confirm against AI::MXNet::Symbol).
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data => [10, 50],
        rnn_t1_data => [10, 50],
        rnn_t2_data => [10, 50],
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'RNNCell: inferred shapes for the three unrolled steps'
    );
}

# Unrolls an LSTMCell (100 hidden units, prefix 'rnn_', forget_bias 1)
# over three steps and verifies:
#   1. the parameter names registered on the cell,
#   2. the output names of the grouped, unrolled symbol,
#   3. the second value returned by infer_shape() for [10, 50] inputs.
# Runs exactly three assertions, each labelled with the cell type so a
# failure can be told apart from the identical checks in the RNN/GRU tests.
sub test_lstm
{
    my $cell = mx->rnn->LSTMCell(100, prefix => 'rnn_', forget_bias => 1);

    # unroll() is evaluated in list context; only its first return value
    # is kept and then wrapped into a single grouped symbol.
    my ($outputs) = $cell->unroll(3, input_prefix => 'rnn_');
    $outputs = mx->sym->Group($outputs);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)],
        'LSTMCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        [qw(rnn_t0_out_output rnn_t1_out_output rnn_t2_out_output)],
        'LSTMCell: unrolled output names'
    );

    # Only the second element of infer_shape()'s return list is checked
    # (presumably the output shapes -- confirm against AI::MXNet::Symbol).
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data => [10, 50],
        rnn_t1_data => [10, 50],
        rnn_t2_data => [10, 50],
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'LSTMCell: inferred shapes for the three unrolled steps'
    );
}

# Builds a two-layer stack of LSTMCells with forget_bias => 2, binds it in
# a Module on cpu(0), initialises the parameters, and checks that the first
# argument whose name ends in 'i2h_bias' equals (within 1e-07, element-wise)
# a vector laid out as:
#   $num_hidden zeros, $num_hidden copies of $forget_bias, 2*$num_hidden zeros.
# Runs exactly one assertion.
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $num_hidden  = 100;    # shared by both cells and the expected vector

    my $stack = mx->rnn->SequentialRNNCell();
    for my $prefix ('l0_', 'l1_')
    {
        $stack->add(
            mx->rnn->LSTMCell($num_hidden, forget_bias => $forget_bias, prefix => $prefix)
        );
    }

    my $dshape = [32, 1, 200];
    my $data   = mx->sym->Variable('data');

    # unroll() is evaluated in list context; only its first return value
    # (the symbol handed to the Module) is used.
    my ($sym) = $stack->unroll(1, inputs => $data, merge_outputs => 1);
    my $mod = mx->mod->Module($sym, context => mx->cpu(0));
    $mod->bind(data_shapes => [['data', $dshape]]);

    $mod->init_params();

    # First argument name ending in 'i2h_bias', in list_arguments order.
    my ($bias_argument) = grep { /i2h_bias$/ } @{ $sym->list_arguments };

    # Concatenate the three segments along dimension 0.
    my $expected_bias = zeros($num_hidden)->glue(
        0,
        $forget_bias * ones($num_hidden),
        zeros(2 * $num_hidden),
    );

    # get_params() is called in list context; the first returned hash is
    # looked up by argument name and converted to a PDL for comparison.
    my ($arg_params) = $mod->get_params();
    my $actual_bias  = $arg_params->{$bias_argument}->aspdl;

    ok(
        (($actual_bias - $expected_bias)->abs < 1e-07)->all,
        'LSTMCell: forget_bias is applied to the initialised i2h bias'
    );
}

# Unrolls a GRUCell (100 hidden units, prefix 'rnn_') over three steps and
# verifies:
#   1. the parameter names registered on the cell,
#   2. the output names of the grouped, unrolled symbol,
#   3. the second value returned by infer_shape() for [10, 50] inputs.
# Runs exactly three assertions, each labelled with the cell type so a
# failure can be told apart from the identical checks in the RNN/LSTM tests.
sub test_gru
{
    my $cell = mx->rnn->GRUCell(100, prefix => 'rnn_');

    # unroll() is evaluated in list context; only its first return value
    # is kept and then wrapped into a single grouped symbol.
    my ($outputs) = $cell->unroll(3, input_prefix => 'rnn_');
    $outputs = mx->sym->Group($outputs);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)],
        'GRUCell: parameter names'
    );
    is_deeply(
        $outputs->list_outputs(),
        [qw(rnn_t0_out_output rnn_t1_out_output rnn_t2_out_output)],
        'GRUCell: unrolled output names'
    );

    # Only the second element of infer_shape()'s return list is checked
    # (presumably the output shapes -- confirm against AI::MXNet::Symbol).
    my (undef, $outs, undef) = $outputs->infer_shape(
        rnn_t0_data => [10, 50],
        rnn_t1_data => [10, 50],
        rnn_t2_data => [10, 50],
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'GRUCell: inferred shapes for the three unrolled steps'
    );
}

sub test_residual
{
    my $cell = mx->rnn->ResidualCell(mx->rnn->GRUCell(50, prefix=>'rnn_'));



( run in 2.594 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )