AI-MXNet

 view release on metacpan or  search on metacpan

t/test_module.t  view on Meta::CPAN

use strict;
use warnings;
use Test::More tests => 257;
use AI::MXNet qw(mx);
use AI::MXNet::Base;
use AI::MXNet::TestUtils qw(almost_equal enumerate same_array dies_like);
use Data::Dumper;

# Builds a Module over a ReLU activation whose data is described with the
# 'TNC' layout, binds it on two CPU contexts, runs one forward/backward
# pass, and checks the shapes reported by get_outputs.
# Emits three is_deeply assertions when two per-context outputs are returned.
sub test_module_layout
{
    my $full_shape = [3, 8, 7];

    # Single 'data' input fed through a ReLU; the symbol is tagged with
    # the __layout__ attribute 'TNC'.
    my $net = mx->sym->Activation(
        data       => mx->sym->Variable('data'),
        act_type   => 'relu',
        __layout__ => 'TNC'
    );

    my $module = mx->mod->Module(
        $net,
        data_names => ['data'],
        context    => [mx->cpu(0), mx->cpu(1)]
    );
    my $data_desc = mx->io->DataDesc('data', $full_shape, layout => 'TNC');
    $module->bind(data_shapes => [$data_desc]);
    $module->init_params();

    # Forward with is_train => 1 on an all-ones batch, then backward with
    # a freshly created all-ones array of the same shape.
    my $batch = mx->io->DataBatch(data => [mx->nd->ones($full_shape)]);
    $module->forward($batch, is_train => 1);
    $module->backward([mx->nd->ones($full_shape)]);

    # With no arguments, the first output is expected to have the full
    # input shape.
    my $merged = $module->get_outputs();
    is_deeply($merged->[0]->shape, $full_shape);

    # NOTE(review): the 0 passed to get_outputs presumably disables merging
    # across contexts, so element 0 is a list of per-context arrays -- confirm
    # against AI::MXNet::Module. Each entry is expected to be [3, 4, 7].
    my $slice_shape = [3, 4, 7];
    my $per_context = $module->get_outputs(0)->[0];
    foreach my $part (@{ $per_context })
    {
        is_deeply($part->shape, $slice_shape);
    }
}

sub test_save_load
{
    my $dict_equ = sub {
        is_deeply([sort keys %$a], [sort keys %$b]);
        for my $k (keys %$a)
        {
            ok(($a->{$k}->aspdl == $b->{$k}->aspdl)->all);
        }
    };
    my $sym = mx->sym->Variable('data');
    $sym = mx->sym->FullyConnected($sym, num_hidden=>100);

    # single device
    my $mod = mx->mod->Module($sym, data_names=>['data']);
    $mod->bind(data_shapes=>[['data', [10, 10]]]);
    $mod->init_params();
    $mod->init_optimizer(optimizer_params=>{learning_rate => 0.1, momentum => 0.9});
    $mod->update();
    $mod->save_checkpoint('test', 0, 1);

    my $mod2 = mx->mod->Module->load('test', 0, 1, data_names=>['data']);
    $mod2->bind(data_shapes=>[['data', [10, 10]]]);
    $mod2->init_optimizer(optimizer_params=>{learning_rate => 0.1, momentum => 0.9});
    is($mod->_symbol->tojson(), $mod2->_symbol->tojson());
    $dict_equ->(($mod->get_params())[0], ($mod2->get_params())[0]);
    $dict_equ->($mod->_updater->states, $mod2->_updater->states);

    # multi device

 view all matches for this distribution
 view release on metacpan -  search on metacpan

( run in 0.487 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )