AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/TestUtils.pm  view on Meta::CPAN

                                      GetCifar10 pdl_maximum pdl_minimum mlp2 conv
                                      check_consistency zip assert enumerate same_array dies_like);
use constant default_numerical_threshold => 1e-6;
=head1 NAME

    AI::MXNet::TestUtils - Convenience subs used in tests.

=head2 same

    Test if two pdl arrays are the same

    Parameters
    ----------
    a : pdl
    b : pdl
=cut

func same(PDL $a, PDL $b)
{
    # Count the positions where the two arrays disagree;
    # the arrays are the same exactly when there are none.
    my $mismatches = ($a != $b)->sum;
    return $mismatches == 0;
}

=head2 reldiff

    Calculate the relative difference between two input arrays

    Calculated as \frac{|a-b|_1}{|a|_1 + |b|_1}, i.e. sum(abs(a - b)) / (sum(abs(a)) + sum(abs(b)))

    Parameters
    ----------
    a : pdl
    b : pdl
=cut

func reldiff(PDL $a, PDL $b)
{
    # L1 distance between the two arrays.
    my $l1_diff = sum(abs($a - $b));

    # Identical arrays: answer 0 directly, which also avoids dividing
    # by a zero norm when both arrays are all zeros.
    return 0 if $l1_diff == 0;

    # Normalise by the combined L1 magnitude of both inputs.
    my $l1_norm = sum(abs($a)) + sum(abs($b));
    return $l1_diff / $l1_norm;
}

=head2 almost_equal

    Test if two pdl arrays are almost equal.
=cut

func almost_equal(PDL $a, PDL $b, Maybe[Num] $threshold=)
{
    # An omitted (or undef) threshold falls back to the module default.
    my $limit = defined $threshold ? $threshold : default_numerical_threshold();

    # The arrays are "almost equal" when their relative difference
    # does not exceed the limit.
    my $relative_difference = reldiff($a, $b);
    return $relative_difference <= $limit;
}

# Make sure the four MNIST ubyte files are present under ./data.
# When any of them is missing, mnist.zip is downloaded into ./data with wget
# and unpacked there with unzip.  Called for its side effects only.
func GetMNIST_ubyte()
{
    mkdir "data" if not -d "data";

    my @required = (
        'data/train-images-idx3-ubyte',
        'data/train-labels-idx1-ubyte',
        'data/t10k-images-idx3-ubyte',
        'data/t10k-labels-idx1-ubyte',
    );

    # grep in boolean context: true when at least one file is missing.
    if(grep { not -f $_ } @required)
    {
        # Backticks capture each command's stdout, which is discarded here.
        `wget http://data.mxnet.io/mxnet/data/mnist.zip -P data`;

        # Extract with "-d data" rather than chdir 'data' ... chdir '..':
        # the chdir pair silently moved the process to the wrong directory
        # when 'data' was a symlink or when the first chdir failed.
        `unzip -u data/mnist.zip -d data`;
    }
    return;
}

# Make sure the CIFAR-10 archive is present under ./data.
# When data/cifar10.zip is missing, it is downloaded into ./data with wget
# and unpacked there with unzip.  Called for its side effects only.
func GetCifar10()
{
    mkdir "data" if not -d "data";

    if(not -f 'data/cifar10.zip')
    {
        # Backticks capture each command's stdout, which is discarded here.
        `wget http://data.mxnet.io/mxnet/data/cifar10.zip -P data`;

        # Extract with "-d data" rather than chdir 'data' ... chdir '..':
        # the chdir pair silently moved the process to the wrong directory
        # when 'data' was a symlink or when the first chdir failed.
        `unzip -u data/cifar10.zip -d data`;
    }
    return;
}

func _pdl_compare(PDL $a, PDL|Num $b, Str $criteria)
{
    # A plain number is expanded into a pdl shaped like $a,
    # with every element set to that number.
    unless(blessed $b)
    {
        my $fill_value = $b;
        $b = $a->copy;
        $b .= $fill_value;
    }

    # For each criteria, a mask of the positions where $b should
    # replace the element taken from $a.
    my %replace_where = (
        'max' => sub { $_[0] < $_[1] },
        'min' => sub { $_[0] > $_[1] },
    );
    my $mask = $replace_where{$criteria}->($a, $b);

    # Start from a copy of $a and overwrite the masked positions from $b.
    my $result = $a->copy;
    $result->where($mask) .= $b->where($mask);
    return $result;
}

func pdl_maximum(PDL $a, PDL|Num $b)
{
    # Element-wise maximum of $a and $b ($b may be a plain number).
    return _pdl_compare($a, $b, 'max');
}

func pdl_minimum(PDL $a, PDL|Num $b)
{
    # Element-wise minimum of $a and $b ($b may be a plain number).
    return _pdl_compare($a, $b, 'min');
}

func mlp2()
{
    # Symbol graph: data -> fc1 (1000 hidden) -> relu -> fc2 (10 hidden).
    my $input = AI::MXNet::Symbol->Variable('data');
    my $fc1   = AI::MXNet::Symbol->FullyConnected(
        data       => $input,
        name       => 'fc1',
        num_hidden => 1000
    );
    my $relu  = AI::MXNet::Symbol->Activation(
        data     => $fc1,
        act_type => 'relu'
    );
    my $fc2   = AI::MXNet::Symbol->FullyConnected(
        data       => $relu,
        name       => 'fc2',
        num_hidden => 10
    );
    return $fc2;
}

func conv()
{
    # Symbol graph: two (conv -> batchnorm -> relu -> max-pool) stages,
    # then flatten -> fc2 (10 hidden) -> softmax output.
    my $net = AI::MXNet::Symbol->Variable('data');

    # Stage 1
    $net = AI::MXNet::Symbol->Convolution(
        data       => $net,
        name       => 'conv1',
        num_filter => 32,
        kernel     => [3,3],
        stride     => [2,2]
    );
    $net = AI::MXNet::Symbol->BatchNorm(data => $net, name => "bn1");
    $net = AI::MXNet::Symbol->Activation(
        data     => $net,
        name     => 'relu1',
        act_type => "relu"
    );
    $net = AI::MXNet::Symbol->Pooling(
        data      => $net,
        name      => 'mp1',
        kernel    => [2,2],
        stride    => [2,2],
        pool_type => 'max'
    );

    # Stage 2
    $net = AI::MXNet::Symbol->Convolution(
        data       => $net,
        name       => 'conv2',
        num_filter => 32,
        kernel     => [3,3],
        stride     => [2,2]
    );
    $net = AI::MXNet::Symbol->BatchNorm(data => $net, name => "bn2");
    $net = AI::MXNet::Symbol->Activation(
        data     => $net,
        name     => 'relu2',
        act_type => "relu"
    );
    $net = AI::MXNet::Symbol->Pooling(
        data      => $net,
        name      => 'mp2',
        kernel    => [2,2],
        stride    => [2,2],
        pool_type => 'max'
    );

    # Classifier head
    $net = AI::MXNet::Symbol->Flatten(data => $net, name => "flatten");
    $net = AI::MXNet::Symbol->FullyConnected(
        data       => $net,
        name       => 'fc2',
        num_hidden => 10
    );
    $net = AI::MXNet::Symbol->SoftmaxOutput(data => $net, name => 'sm');
    return $net;
}

=head2 check_consistency

    Check that a symbol gives the same output across different running contexts.

    Parameters
    ----------
    sym : Symbol or list of Symbols
        symbol(s) to run the consistency test
    ctx_list : list
        running context. See example for more detail.
    scale : float, optional
        standard deviation of the inner normal distribution. Used in initialization
    grad_req : str or list of str or dict of str to str
        gradient requirement.
=cut

# Lookup table from dtype name to an integer code.
# NOTE(review): presumably these codes mirror the dtype enumeration used by
# the MXNet backend -- confirm against the code that reads this table.
my %dtypes = (
    float32 => 0,
    float64 => 1,
    float16 => 2,
    uint8   => 3,
    int32   => 4
);

func check_consistency(
    SymbolOrArrayOfSymbols              :$sym,
    ArrayRef                            :$ctx_list,
    Num                                 :$scale=1,
    Str|ArrayRef[Str]|HashRef[Str]      :$grad_req='write',
    Maybe[HashRef[AI::MXNet::NDArray]]  :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]]  :$aux_params=,
    Maybe[HashRef[Num]|Num]             :$tol=,
    Bool                                :$raise_on_err=1,
    Maybe[AI::MXNer::NDArray]           :$ground_truth=
)
{
    $tol //= {
        float16 => 1e-1,
        float32 => 1e-3,
        float64 => 1e-5,
        uint8   => 0,
        int32   => 0
    };
    $tol = {
        float16 => $tol,
        float32 => $tol,
        float64 => $tol,
        uint8   => $tol,
        int32   => $tol
    } unless ref $tol;

    Test::More::ok(@$ctx_list > 1);
    if(blessed $sym)
    {
        $sym = [($sym)x@$ctx_list];
    }
    else



( run in 1.816 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )