AI-MXNet

 view release on metacpan or search on metacpan

t/test_executor.t  view on Meta::CPAN

use strict;
use warnings;
use Test::More tests => 2283;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
use PDL;

sub check_bind_with_uniform
{
    my ($uf, $gf, $dim, $sf, $lshape, $rshape) = @_;
    my $shape = (random($dim)*int(1000**(1.0/$dim))+1)->floor->unpdl;
    my $lhs = mx->symbol->Variable('lhs');
    my $rhs = mx->symbol->Variable('rhs');
    my $ret;
    if(defined $sf)
    {
        $ret = &{$sf}($lhs, $rhs);
    }
    else
    {
        $ret = &{$uf}($lhs, $rhs);
    }

    is_deeply($ret->list_arguments(), ['lhs', 'rhs']);
    $lshape //= $shape;
    $rshape //= $shape;

    my $lhs_arr = mx->nd->array(random(reverse (@$lshape)));
    my $rhs_arr = mx->nd->array(random(reverse (@$rshape)));
    my $lhs_grad = mx->nd->empty($lshape);
    my $rhs_grad = mx->nd->empty($rshape);
    my $executor = $ret->bind(
        ctx       => mx->Context('cpu'),
        args      => [$lhs_arr, $rhs_arr],
        args_grad => [$lhs_grad, $rhs_grad]
    );

    my $exec3 = $ret->bind(
        ctx  => mx->Context('cpu'),
        args => [$lhs_arr, $rhs_arr]
    );

    my $exec4 = $ret->bind(
        ctx  => mx->Context('cpu'),
        args => {'rhs' => $rhs_arr, 'lhs' => $lhs_arr},
        args_grad=>{'lhs' => $lhs_grad, 'rhs' => $rhs_grad}
    );

    $executor->forward(1);
    $exec3->forward(1);
    $exec4->forward(1);
    my $out2 = $executor->outputs->[0]->aspdl;
    my $out1 = &{$uf}($lhs_arr->aspdl, $rhs_arr->aspdl);
    my $out3 = $exec3->outputs->[0]->aspdl;
    my $out4 = $exec4->outputs->[0]->aspdl;
    ok(reldiff($out1, $out2) < 1e-6);
    ok(reldiff($out1, $out3) < 1e-6);
    ok(reldiff($out1, $out4) < 1e-6);
    # test gradient

    my $out_grad = mx->nd->ones([reverse @{$out2->shape->unpdl}]);
    my ($lhs_grad2, $rhs_grad2) = &{$gf}(
        $out_grad->aspdl,
        $lhs_arr->aspdl,
        $rhs_arr->aspdl



( run in 0.918 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )