AI-MXNet
view release on metacpan - search on metacpan
view release on metacpan or search on metacpan
t/test_module.t view on Meta::CPAN
use strict;
use warnings;
use Test::More tests => 257;
use AI::MXNet qw(mx);
use AI::MXNet::Base;
use AI::MXNet::TestUtils qw(almost_equal enumerate same_array dies_like);
use Data::Dumper;
# Builds a single relu Activation over a 'data' variable tagged with the
# 'TNC' layout, binds it in a Module spread over two CPU contexts, runs one
# forward and one backward pass, and then checks the shapes of the outputs.
sub test_module_layout
{
    my $full_shape = [3, 8, 7];

    my $data_var = mx->sym->Variable('data');
    my $relu_net = mx->sym->Activation(
        data       => $data_var,
        act_type   => 'relu',
        __layout__ => 'TNC'
    );

    my $module = mx->mod->Module(
        $relu_net,
        data_names => ['data'],
        context    => [mx->cpu(0), mx->cpu(1)]
    );

    # Declare the input with the same layout string that the symbol carries.
    my @input_descs = (mx->io->DataDesc('data', $full_shape, layout => 'TNC'));
    $module->bind(data_shapes => \@input_descs);
    $module->init_params();

    # One training-mode forward pass on an all-ones batch ...
    $module->forward(
        mx->io->DataBatch(data => [mx->nd->ones($full_shape)]),
        is_train => 1
    );

    # ... followed by a backward pass fed with all-ones of the same shape.
    my @head_grads = (mx->nd->ones($full_shape));
    $module->backward(\@head_grads);

    # The first output from the default get_outputs() call must have the
    # full input shape.
    is_deeply($module->get_outputs()->[0]->shape, $full_shape);

    # Every entry of the first output from get_outputs(0) must have the
    # halved shape.
    # NOTE(review): presumably the 0 disables merging across contexts and the
    # 8 -> 4 split follows the 'N' axis of the layout -- confirm against the
    # Module documentation.
    my $split_shape = [3, 4, 7];
    foreach my $device_output (@{ $module->get_outputs(0)->[0] })
    {
        is_deeply($device_output->shape, $split_shape);
    }
}
sub test_save_load
{
my $dict_equ = sub {
is_deeply([sort keys %$a], [sort keys %$b]);
for my $k (keys %$a)
{
ok(($a->{$k}->aspdl == $b->{$k}->aspdl)->all);
}
};
my $sym = mx->sym->Variable('data');
$sym = mx->sym->FullyConnected($sym, num_hidden=>100);
# single device
my $mod = mx->mod->Module($sym, data_names=>['data']);
$mod->bind(data_shapes=>[['data', [10, 10]]]);
$mod->init_params();
$mod->init_optimizer(optimizer_params=>{learning_rate => 0.1, momentum => 0.9});
$mod->update();
$mod->save_checkpoint('test', 0, 1);
my $mod2 = mx->mod->Module->load('test', 0, 1, data_names=>['data']);
$mod2->bind(data_shapes=>[['data', [10, 10]]]);
$mod2->init_optimizer(optimizer_params=>{learning_rate => 0.1, momentum => 0.9});
is($mod->_symbol->tojson(), $mod2->_symbol->tojson());
$dict_equ->(($mod->get_params())[0], ($mod2->get_params())[0]);
$dict_equ->($mod->_updater->states, $mod2->_updater->states);
# multi device
view all matches for this distribution
view release on metacpan - search on metacpan
( run in 0.487 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )