AI-MXNet

 View release on MetaCPAN or search on MetaCPAN

t/test_symbol.t  view on Meta::CPAN

use strict;
use warnings;
use Test::More tests => 98;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(mlp2 conv check_consistency zip assert enumerate);
use Storable qw(freeze thaw);
use PDL;

sub test_symbol_compose
{
    # Build net1 (fc1 -> fc2) and check that its arguments are listed
    # in declaration order.
    my $data = mx->symbol->Variable('data');
    my $net1 = mx->symbol->FullyConnected(data=>$data, name=>'fc1', num_hidden=>10);
    $net1 = mx->symbol->FullyConnected(data=>$net1, name=>'fc2', num_hidden=>100);
    is_deeply($net1->list_arguments(), ['data',
                              'fc1_weight', 'fc1_bias',
                              'fc2_weight', 'fc2_bias'],
              'net1 arguments listed in declaration order');

    # net2 leaves its first input (fc3's data) unbound so it can be
    # composed onto another symbol later.
    my $net2 = mx->symbol->FullyConnected(name=>'fc3', num_hidden=>10);
    $net2 = mx->symbol->Activation(data=>$net2, act_type=>'relu');
    $net2 = mx->symbol->FullyConnected(data=>$net2, name=>'fc4', num_hidden=>20);
    # AI::MXNet symbols overload &{}; invoking the symbol as a code ref
    # composes it, plugging net1 into net2's unbound 'fc3_data' input.
    my $composed = &{$net2}(fc3_data=>$net1, name=>'composed');
    my $multi_out = mx->symbol->Group([$composed, $net1]);
    # is() with an explicit expected count diagnoses failures better
    # than ok($n == 2), which only reports "not ok".
    is(scalar(@{ $multi_out->list_outputs() }), 2, 'grouped symbol has two outputs');
}

test_symbol_compose();

sub test_symbol_copy
{
    # A deep copy of a symbol must serialize to exactly the same JSON
    # as the original.
    my $original = mx->symbol->Variable('data');
    my $clone    = $original->deepcopy;
    is($original->tojson, $clone->tojson);
}

test_symbol_copy();

sub test_symbol_internal
{
    # get_internals() exposes intermediate nodes of a composed graph;
    # slicing out 'fc1_output' must reproduce the argument list of the
    # standalone fc1 symbol.
    my $input   = mx->symbol->Variable('data');
    my $fc1_sym = mx->symbol->FullyConnected(data=>$input, name=>'fc1', num_hidden=>10);
    my $net     = mx->symbol->FullyConnected(data=>$fc1_sym, name=>'fc2', num_hidden=>100);
    is_deeply($net->list_arguments, ['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias']);

    my $internals = $net->get_internals();
    my $fc1_out   = $internals->slice('fc1_output');
    is_deeply($fc1_out->list_arguments, $fc1_sym->list_arguments);
}

test_symbol_internal();

sub test_symbol_children
{
    my $data = mx->symbol->Variable('data');
    my $oldfc = mx->symbol->FullyConnected(data=>$data, name=>'fc1', num_hidden=>10);
    my $net1 = mx->symbol->FullyConnected(data=>$oldfc, name=>'fc2', num_hidden=>100);

    is_deeply($net1->get_children()->list_outputs(), ['fc1_output', 'fc2_weight', 'fc2_bias']);
    is_deeply($net1->get_children()->get_children()->list_outputs() , ['data', 'fc1_weight', 'fc1_bias']);
    is_deeply($net1->get_children()->slice('fc2_weight')->list_arguments(), ['fc2_weight']);
    ok(not defined $net1->get_children()->slice('fc2_weight')->get_children());

    $data = mx->sym->Variable('data');
    my $sliced = mx->sym->SliceChannel($data, num_outputs=>3, name=>'slice');
    my $concat = mx->sym->Concat(@{ $sliced });

t/test_symbol.t  view on Meta::CPAN

    my $data2 = mx->symbol->load($fname);
    # save because of order
    is($sym->tojson, $data2->tojson);
    unlink $fname;
}

test_symbol_saveload();

sub test_symbol_infer_type
{
    # Type inference: float16 input fed through an explicit Cast to
    # float32 forces the FC parameters and the softmax output to
    # float32, while the 'data' argument itself stays float16.
    my $input  = mx->symbol->Variable('data');
    my $casted = mx->symbol->Cast(data=>$input, dtype=>'float32');
    my $fc     = mx->symbol->FullyConnected(data => $casted, name=>'fc1', num_hidden=>128);
    my $net    = mx->symbol->SoftmaxOutput(data => $fc, name => 'softmax');

    my ($arg_types, $out_types, $aux_types) = $net->infer_type(data=>'float16');
    is_deeply($arg_types, [qw/float16 float32 float32 float32/]);
    is_deeply($out_types, ['float32']);
    is_deeply($aux_types, []);
}

test_symbol_infer_type();

sub test_symbol_infer_shape
{
    # Partial vs. full shape inference on a small RNN-style cell:
    # out = relu(x2h(data) + h2h(prevstate)).
    my $hidden = 128;
    my $in_dim = 64;
    my $batch  = 10;

    my $data = mx->symbol->Variable('data');
    my $prev = mx->symbol->Variable('prevstate');
    my $x2h  = mx->symbol->FullyConnected(data=>$data, name=>'x2h', num_hidden=>$hidden);
    my $h2h  = mx->symbol->FullyConnected(data=>$prev, name=>'h2h', num_hidden=>$hidden);

    my $out  = mx->symbol->Activation(data=>mx->sym->elemwise_add($x2h, $h2h), name=>'out', act_type=>'relu');

    # Full inference fails while prevstate's shape is unknown, so all
    # three result slots come back undef.
    my @full = $out->infer_shape(data=>[$batch, $in_dim]);
    is_deeply(\@full, [undef, undef, undef]);

    # Partial inference fills in what it can; h2h_weight stays empty.
    my ($arg_shapes, $out_shapes, $aux_shapes) = $out->infer_shape_partial(data=>[$batch, $in_dim]);
    my %shape_of;
    @shape_of{ @{ $out->list_arguments } } = @{ $arg_shapes };
    is_deeply($shape_of{data}, [$batch, $in_dim]);
    is_deeply($shape_of{x2h_weight}, [$hidden, $in_dim]);
    is_deeply($shape_of{h2h_weight}, []);

    # With the state shape recovered from the partial pass, full shape
    # inference now succeeds for every argument.
    my $state_shape = $out_shapes->[0];
    ($arg_shapes, $out_shapes, $aux_shapes) = $out->infer_shape(data=>[$batch, $in_dim], prevstate=>$state_shape);
    @shape_of{ @{ $out->list_arguments } } = @{ $arg_shapes };
    is_deeply($shape_of{data}, [$batch, $in_dim]);
    is_deeply($shape_of{x2h_weight}, [$hidden, $in_dim]);
    is_deeply($shape_of{h2h_weight}, [$hidden, $hidden]);
}

test_symbol_infer_shape();

sub test_symbol_infer_shape_var
{
    # A shape supplied when constructing a Variable seeds inference,
    # and a shape passed directly to infer_shape() overrides it.
    my $declared = [2, 3];
    my $a = mx->symbol->Variable('a', shape=>$declared);
    my $b = mx->symbol->Variable('b');
    my $c = mx->symbol->elemwise_add($a, $b);

    my ($args, $outs, $auxs) = $c->infer_shape();
    is_deeply($args->[0], $declared);
    is_deeply($args->[1], $declared);
    is_deeply($outs->[0], $declared);

    my $override = [5, 6];
    ($args, $outs, $auxs) = $c->infer_shape(a=>$override);
    is_deeply($args->[0], $override);
    is_deeply($args->[1], $override);
    is_deeply($outs->[0], $override);
}

test_symbol_infer_shape_var();

# Assert that two symbols are equivalent: identical argument,
# auxiliary-state and output name lists, and numerically consistent
# results when both are evaluated on the given context.
sub check_symbol_consistency
{
    my ($lhs, $rhs, $ctx) = @_;
    is_deeply($lhs->list_arguments(), $rhs->list_arguments());
    is_deeply($lhs->list_auxiliary_states(), $rhs->list_auxiliary_states());
    is_deeply($lhs->list_outputs(), $rhs->list_outputs());
    check_consistency(sym => [$lhs, $rhs], ctx_list => [$ctx, $ctx]);
}

sub test_load_000800
{
    my ($data, $weight, $fc1, $act1);
    {
        local($mx::AttrScope) = mx->AttrScope(ctx_group=>'stage1');
        $data = mx->symbol->Variable('data', lr_mult=>0.2);
        $weight = mx->sym->Variable('fc1_weight', lr_mult=>1.2);
        $fc1  = mx->symbol->FullyConnected(data => $data, weight=>$weight, name=>'fc1', num_hidden=>128, wd_mult=>0.3);
        $act1 = mx->symbol->Activation(data => $fc1, name=>'relu1', act_type=>"relu");
    }
    my ($fc2, $act2, $fc3, $sym1);
    {
        local($mx::AttrScope) = mx->AttrScope(ctx_group=>'stage2');
        $fc2  = mx->symbol->FullyConnected(data => $act1, name => 'fc2', num_hidden => 64, lr_mult=>0.01);
        $act2 = mx->symbol->Activation(data => $fc2, name=>'relu2', act_type=>"relu");
        $fc3  = mx->symbol->FullyConnected(data => $act2, name=>'fc3', num_hidden=>10);
        $fc3  = mx->symbol->BatchNorm($fc3, name=>'batchnorm0');
        $sym1 = mx->symbol->SoftmaxOutput(data => $fc3, name => 'softmax')
    }
    { local $/ = undef; my $json = <DATA>; open(F, ">save_000800.json"); print F $json; close(F); };
    my $sym2 = mx->sym->load('save_000800.json');
    unlink 'save_000800.json';

    my %attr1 = %{ $sym1->attr_dict };
    my %attr2 = %{ $sym2->attr_dict };
    while(my ($k, $v1) = each %attr1)
    {
        ok(exists $attr2{ $k });
        my $v2 = $attr2{$k};
        while(my ($kk, $vv1) = each %{ $v1 })
        {
            if($kk =~ /^__/ and $kk =~ /__$/)
            {



( run in 1.597 second using v1.01-cache-2.11-cpan-39bf76dae61 )