AI-MXNet

 view release on metacpan or search on metacpan

t/test_rnn.t  view on Meta::CPAN

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);
use PDL;
use Test::More tests => 54;

# Unroll a plain RNNCell (100 hidden units) for three steps and check:
#   * the parameter names the cell registers under the 'rnn_' prefix,
#   * the names of the grouped per-step output symbols,
#   * the output shapes inferred from three [10, 50] inputs.
sub test_rnn
{
    my $cell = mx->rnn->RNNCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'RNNCell registers i2h/h2h weight and bias parameters'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'RNNCell unroll names one output per time step'
    );
    # infer_shape's second return value holds the output shapes; the
    # argument and auxiliary shapes are not needed here.
    my (undef, $outs) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'RNNCell infers a [10, 100] output for every step'
    );
}

# Unroll an LSTMCell (100 hidden units, forget_bias => 1) for three steps and
# check:
#   * the parameter names the cell registers under the 'rnn_' prefix,
#   * the names of the grouped per-step output symbols,
#   * the output shapes inferred from three [10, 50] inputs.
sub test_lstm
{
    my $cell = mx->rnn->LSTMCell(100, prefix=>'rnn_', forget_bias => 1);
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'LSTMCell registers i2h/h2h weight and bias parameters'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'LSTMCell unroll names one output per time step'
    );
    # infer_shape's second return value holds the output shapes; the
    # argument and auxiliary shapes are not needed here.
    my (undef, $outs) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'LSTMCell infers a [10, 100] output for every step'
    );
}

# Build a two-layer stack of LSTMCells with a non-default forget_bias, bind
# and initialise it as a Module, then check one layer's initialised i2h bias.
# The expected layout is four slices of $num_hidden entries: zeros, then
# forget_bias, then two slices of zeros.
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $num_hidden  = 100;
    my $stack = mx->rnn->SequentialRNNCell();
    $stack->add(mx->rnn->LSTMCell($num_hidden, forget_bias=>$forget_bias, prefix=>'l0_'));
    $stack->add(mx->rnn->LSTMCell($num_hidden, forget_bias=>$forget_bias, prefix=>'l1_'));

    my $dshape = [32, 1, 200];
    my $data   = mx->sym->Variable('data');

    my ($sym) = $stack->unroll(1, inputs => $data, merge_outputs => 1);
    my $mod = mx->mod->Module($sym, context => mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);

    $mod->init_params();

    # First argument whose name ends in 'i2h_bias' (presumably l0_'s, given
    # list_arguments ordering -- NOTE(review): confirm). Fail with a clear
    # message instead of an "undefined value" method-call error.
    my ($bias_argument) = grep { /i2h_bias\z/ } @{ $sym->list_arguments };
    die "test_lstm_forget_bias: no argument ending in 'i2h_bias' found\n"
        unless defined $bias_argument;

    my $expected_bias = zeros($num_hidden)->glue(
        0,
        $forget_bias * ones($num_hidden),
        zeros(2 * $num_hidden),
    );

    # get_params returns (arg_params, aux_params); only the former is used.
    my ($arg_params) = $mod->get_params();
    my $actual_bias  = $arg_params->{$bias_argument}->aspdl;
    ok(
        (($actual_bias - $expected_bias)->abs < 1e-07)->all,
        "$bias_argument is initialised with forget_bias=$forget_bias"
    );
}

# Unroll a GRUCell (100 hidden units) for three steps and check:
#   * the parameter names the cell registers under the 'rnn_' prefix,
#   * the names of the grouped per-step output symbols,
#   * the output shapes inferred from three [10, 50] inputs.
sub test_gru
{
    my $cell = mx->rnn->GRUCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{$cell->params->_params}],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'GRUCell registers i2h/h2h weight and bias parameters'
    );
    is_deeply(
        $outputs->list_outputs(),
        ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'],
        'GRUCell unroll names one output per time step'
    );
    # infer_shape's second return value holds the output shapes; the
    # argument and auxiliary shapes are not needed here.
    my (undef, $outs) = $outputs->infer_shape(
        rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]
    );
    is_deeply(
        $outs,
        [[10, 100], [10, 100], [10, 100]],
        'GRUCell infers a [10, 100] output for every step'
    );
}

# Wrap a 50-unit GRUCell in a ResidualCell, unroll it for two steps and check:
#   * parameter names and output symbol names,
#   * inferred output shapes for [10, 50] inputs,
#   * forward values: with every GRU weight and bias zeroed and all-ones
#     inputs, each step's output is expected to equal all ones.
sub test_residual
{
    my $cell = mx->rnn->ResidualCell(mx->rnn->GRUCell(50, prefix=>'rnn_'));
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'ResidualCell exposes the wrapped GRU parameters'
    );
    is_deeply(
        $outputs->list_outputs,
        ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output'],
        'ResidualCell unroll names one residual output per step'
    );

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50]);
    is_deeply($outs, [[10, 50], [10, 50]], 'ResidualCell keeps the input shape');

    # eval() yields an array ref of computed NDArrays; keep it apart from the
    # Symbol held in $outputs.
    my $results = $outputs->eval(args => {
        rnn_t0_data=>mx->nd->ones([10, 50]),
        rnn_t1_data=>mx->nd->ones([10, 50]),
        rnn_i2h_weight=>mx->nd->zeros([150, 50]),
        rnn_i2h_bias=>mx->nd->zeros([150]),
        rnn_h2h_weight=>mx->nd->zeros([150, 50]),
        rnn_h2h_bias=>mx->nd->zeros([150])
    });
    my $expected_outputs = mx->nd->ones([10, 50])->aspdl;

    # same() only returns a boolean, so its result must be checked explicitly.
    # fail() (rather than ok()) records a test only on a mismatch, which keeps
    # the fixed `tests => 54` plan valid on a passing run.
    for my $step (0 .. 1) {
        same($results->[$step]->aspdl, $expected_outputs)
            or fail("residual output for step $step equals all ones");
    }
}

sub test_residual_bidirectional
{
    my $cell = mx->rnn->ResidualCell(
        mx->rnn->BidirectionalCell(
            mx->rnn->GRUCell(25, prefix=>'rnn_l_'),
            mx->rnn->GRUCell(25, prefix=>'rnn_r_')
        )
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];



( run in 0.958 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )