AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/TestUtils.pm  view on Meta::CPAN

package AI::MXNet::TestUtils;
use strict;
use warnings;
use PDL;
use Carp;
use Scalar::Util qw(blessed);
use AI::MXNet::Function::Parameters;
use Exporter;
use base qw(Exporter);
@AI::MXNet::TestUtils::EXPORT_OK = qw(same reldiff almost_equal GetMNIST_ubyte
                                      GetCifar10 pdl_maximum pdl_minimum mlp2 conv
                                      check_consistency zip assert enumerate same_array dies_like);
use constant default_numerical_threshold => 1e-6;
=head1 NAME

    AI::MXNet::TestUtils - Convenience subs used in tests.

=head2 same

    Test if two pdl arrays are the same

    Parameters
    ----------
    a : pdl
    b : pdl
=cut

func same(PDL $a, PDL $b)
{
    # Count the positions at which the two arrays differ; the arrays
    # are the same exactly when that count is zero.
    my $mismatch_count = ($a != $b)->sum;
    return $mismatch_count == 0;
}

=head2 reldiff

    Calculate the relative difference between two input arrays

    Calculated by \frac{|a-b|_1}{|a|_1 + |b|_1}, where |x|_1 is the sum of the absolute values of x.

    Parameters
    ----------
    a : pdl
    b : pdl
=cut

func reldiff(PDL $a, PDL $b)
{
    # Sum of the absolute element-wise differences.
    my $abs_diff_total = sum(abs($a - $b));

    # No difference at all: answer 0 without dividing, which also keeps
    # two all-zero inputs from producing 0/0.
    return 0 if $abs_diff_total == 0;

    # Normalise by the combined absolute magnitude of both inputs.
    my $abs_total = sum(abs($a)) + sum(abs($b));
    return $abs_diff_total / $abs_total;
}

=head2 almost_equal

    Test if two pdl arrays are almost equal.
=cut

func almost_equal(PDL $a, PDL $b, Maybe[Num] $threshold=)
{
    # An omitted or undef threshold falls back to the module default.
    my $limit = $threshold // default_numerical_threshold;

    # The arrays are "almost equal" when their relative difference does
    # not exceed the limit.
    return reldiff($a, $b) <= $limit;
}

# Make sure the four MNIST ubyte files are present under ./data,
# downloading and unpacking mnist.zip when any of them is missing.
#
# Takes no arguments and returns nothing useful.  The helper is
# best-effort: failures of mkdir, wget or unzip are reported with carp
# but never raise an exception.
func GetMNIST_ubyte()
{
    if(not -d "data")
    {
        mkdir "data"
            or carp "GetMNIST_ubyte: cannot create directory 'data': $!";
    }
    my @required_files = qw(
        train-images-idx3-ubyte
        train-labels-idx1-ubyte
        t10k-images-idx3-ubyte
        t10k-labels-idx1-ubyte
    );
    # Number of required files that are not yet on disk.
    my $missing = grep { not -f "data/$_" } @required_files;
    if($missing)
    {
        # Backticks are kept so the tools' stdout is captured and
        # discarded rather than mixed into the caller's output.
        `wget http://data.mxnet.io/mxnet/data/mnist.zip -P data`;
        carp "GetMNIST_ubyte: wget exited with status $?" if $? != 0;
        # Unpack with -d instead of chdir'ing into data and back: an
        # unchecked chdir that failed would unpack in the wrong place
        # and the following chdir '..' would leave the process in the
        # parent of its original working directory.
        `unzip -u data/mnist.zip -d data`;
        carp "GetMNIST_ubyte: unzip exited with status $?" if $? != 0;
    }
    return;
}

# Make sure the CIFAR-10 archive is present under ./data, downloading
# and unpacking cifar10.zip when the archive file is missing.
#
# Takes no arguments and returns nothing useful.  The helper is
# best-effort: failures of mkdir, wget or unzip are reported with carp
# but never raise an exception.
func GetCifar10()
{
    if(not -d "data")
    {
        mkdir "data"
            or carp "GetCifar10: cannot create directory 'data': $!";
    }
    if(not -f 'data/cifar10.zip')
    {
        # Backticks are kept so the tools' stdout is captured and
        # discarded rather than mixed into the caller's output.
        `wget http://data.mxnet.io/mxnet/data/cifar10.zip -P data`;
        carp "GetCifar10: wget exited with status $?" if $? != 0;
        # Unpack with -d instead of chdir'ing into data and back: an
        # unchecked chdir that failed would unpack in the wrong place
        # and the following chdir '..' would leave the process in the
        # parent of its original working directory.
        `unzip -u data/cifar10.zip -d data`;
        carp "GetCifar10: unzip exited with status $?" if $? != 0;
    }
    return;
}

# Element-wise selection between $a and $b: with criteria 'max' the
# result holds the larger of each pair, with 'min' the smaller.  $b may
# be a pdl or a plain number; the caller's $a is not modified.
func _pdl_compare(PDL $a, PDL|Num $b, Str $criteria)
{
    # A plain (unblessed) number is expanded into a copy of $a whose
    # elements are all set to that number.
    unless(blessed $b)
    {
        my $scalar = $b;
        $b = $a->copy;
        $b .= $scalar;
    }

    # For each criteria, the test marking the positions at which the
    # element of $b must replace the element of $a.
    my %replace_test_for = (
        max => sub { $_[0] < $_[1] },
        min => sub { $_[0] > $_[1] },
    );
    my $mask = $replace_test_for{$criteria}->($a, $b);

    # Start from a copy of $a and overwrite only the masked positions.
    my $result = $a->copy;
    $result->where($mask) .= $b->where($mask);
    return $result;
}

func pdl_maximum(PDL $a, PDL|Num $b)



( run in 0.740 second using v1.01-cache-2.11-cpan-39bf76dae61 )