    ## Convolutional NN for recognizing hand-written digits in MNIST dataset
    ## It's considered the "Hello, World" of neural networks
    ## For more info about the MNIST problem please refer to

    use strict;
    use warnings;
    use AI::MXNet qw(mx);
    use AI::MXNet::TestUtils qw(GetMNIST_ubyte);
    use Test::More tests => 1;

    # symbol net
    my $batch_size = 100;

    ### model
    # Symbolic input placeholder named 'data'.
    my $data = mx->symbol->Variable('data');

    # First stage: 3x3 convolution with 32 filters and stride 2 ...
    my $conv1 = mx->symbol->Convolution(
        data       => $data,
        name       => 'conv1',
        num_filter => 32,
        kernel     => [3,3],
        stride     => [2,2],
    );

    # ... followed by batch normalisation ...
    my $bn1 = mx->symbol->BatchNorm(
        data => $conv1,
        name => "bn1",
    );

    # ... a ReLU activation ...
    my $act1 = mx->symbol->Activation(
        data     => $bn1,
        name     => 'relu1',
        act_type => "relu",
    );

    # ... and 2x2 max pooling with stride 2.
    my $mp1 = mx->symbol->Pooling(
        data      => $act1,
        name      => 'mp1',
        kernel    => [2,2],
        stride    => [2,2],
        pool_type => 'max',
    );

                        ($acc_g + $self->epsilon)->sqrt
    $acc_delta .= $self->rho * $acc_delta + (1 - $self->rho) * $current_delta * $current_delta;
    $weight -= $current_delta + $wd * $weight;


# For test use
package AI::MXNet::Test;
use Mouse;

extends 'AI::MXNet::Optimizer';

# Create a state to duplicate weight
method create_state(Index $index, AI::MXNet::NDArray $weight)
    return AI::MXNet::NDArray->zeros(
                ctx => $weight->context

package AI::MXNet::TestUtils;
use strict;
use warnings;
use PDL;
use Carp;
use Scalar::Util qw(blessed);
use AI::MXNet::Function::Parameters;
use Exporter;
use base qw(Exporter);
@AI::MXNet::TestUtils::EXPORT_OK = qw(same reldiff almost_equal GetMNIST_ubyte
                                      GetCifar10 pdl_maximum pdl_minimum mlp2 conv
                                      check_consistency zip assert enumerate same_array dies_like);
use constant default_numerical_threshold => 1e-6;
=head1 NAME

    AI::MXNet::TestUtils - Convenience subs used in tests.

=head2 same

    Tests whether two PDL arrays are the same.

    a : pdl
    b : pdl

=cut

# Reports whether two PDL arrays hold the same values.
#
#   $a : PDL - first array
#   $b : PDL - second array
#
# Returns a true value when no element of $a differs from the matching
# element of $b (the element-wise "!=" mask sums to zero), false otherwise.
func same(PDL $a, PDL $b)
{
    return ($a != $b)->sum == 0;
}

use strict;
use warnings;
use Test::More tests => 1;
BEGIN { use_ok('AI::MXNet') };

use strict;
use warnings;
use Test::More tests => 14;
use AI::MXNet qw(mx);
use Storable;

sub contains
    my ($x, $y) = @_;
    while(my ($k, $v) = each %$x)
        return 0 unless exists $y->{$k};
        if(ref $y->{$k} and ref $y->{$k} eq 'HASH')

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(GetMNIST_ubyte);
use Test::More tests => 1;

## speed up the tests when gpu present
## The probe runs in a child process: it tries to create a one-element array
## on gpu(0) and print its value, with stderr discarded. $gpu_present is true
## only when the child printed exactly "1".
## The child is started with $^X (the perl binary running this script) rather
## than whatever "perl" comes first in PATH, so the probe uses the same
## interpreter as the test itself.
my $gpu_present = (`"$^X" -e 'use AI::MXNet qw(mx); print mx->nd->ones([1], ctx => mx->gpu(0))->asscalar' 2>/dev/null` eq '1');

# symbol net
my $batch_size = 100;

### model
# Symbolic input placeholder named 'data'.
my $data = mx->symbol->Variable('data');

# First convolution: 32 filters, 3x3 kernel, stride 2.
my $conv1 = mx->symbol->Convolution(
    data       => $data,
    name       => 'conv1',
    num_filter => 32,
    kernel     => [3,3],
    stride     => [2,2],
);

use strict;
use warnings;
use Test::More tests => 2283;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
use PDL;

sub check_bind_with_uniform
    my ($uf, $gf, $dim, $sf, $lshape, $rshape) = @_;
    my $shape = (random($dim)*int(1000**(1.0/$dim))+1)->floor->unpdl;
    my $lhs = mx->symbol->Variable('lhs');
    my $rhs = mx->symbol->Variable('rhs');
    my $ret;
    if(defined $sf)

use strict;
use warnings;
use Test::More tests => 18;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(mlp2);

# Checks a symbol's argument shapes against a set of expectations.
#
#   $sym             - object whose list_arguments() returns an array ref of
#                      argument names
#   $arg_shapes      - array ref of shapes, positionally aligned with the
#                      names returned by list_arguments()
#   %expected_shapes - argument name => expected shape
#
# Runs exactly one is_deeply() check per entry of %expected_shapes.
sub _test_shapes
{
    my ($sym, $arg_shapes, %expected_shapes) = @_;

    # Pair each argument name with the shape found at the same position.
    my %arg_shape_dict;
    @arg_shape_dict{ @{ $sym->list_arguments() } } = @{ $arg_shapes };

    # Sorted keys give a stable check order from run to run, and the label
    # makes a failure point directly at the offending argument.
    for my $name (sort keys %expected_shapes)
    {
        is_deeply($arg_shape_dict{$name}, $expected_shapes{$name}, "shape of argument '$name'");
    }
}

use AI::MXNet qw(mx);
use Test::More tests => 31;
use AI::MXNet::TestUtils qw(same reldiff GetMNIST_ubyte GetCifar10);
use PDL;
use PDL::Types;
use PDL::NiceSlice;

sub test_Cifar10Rec()
    my $dataiter = mx->io->ImageRecordIter({

use strict;
use warnings;
use Test::More tests => 1;
use AI::MXNet qw(mx);
use Time::HiRes qw(time);

sub run_imageiter
    my ($path_rec, $n, $batch_size) = @_;
    $batch_size //= 32;
    my $data = mx->img->ImageIter(
        data_shape=>[3, 224, 224],

use strict;
use warnings;
use Test::More tests => 38;
use AI::MXNet qw(mx);

my $shape = [4, 4];
my $keys  = [5,7,9];

sub init_kv
    # init kv
    my $kv = mx->kv->create();
    # single

use strict;
use warnings;
use Test::More tests => 4;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(reldiff);
use AI::MXNet::Base;

# The visible part of this test sets up two CPU contexts (cpu(0) and cpu(1))
# and three symbolic input variables; the rest of the sub is outside this view.
sub test_chain
    my $ctx1 = mx->cpu(0);
    my $ctx2 = mx->cpu(1);
    my $n = 2;
    my $data1 = mx->sym->Variable('data1');
    my $data2 = mx->sym->Variable('data2');
    # NOTE(review): $data3 is created with the symbol name 'data2', the same
    # name as $data2 above. This looks like a copy/paste slip for 'data3';
    # confirm against the rest of the test before changing it.
    my $data3 = mx->sym->Variable('data2');

use strict;
use warnings;
use Test::More tests => 257;
use AI::MXNet qw(mx);
use AI::MXNet::Base;
use AI::MXNet::TestUtils qw(almost_equal enumerate same_array dies_like);
use Data::Dumper;

sub test_module_layout
    my $sym = mx->sym->Variable('data');
    $sym = mx->sym->Activation(data=>$sym, act_type=>'relu', __layout__=>'TNC');

    my $dshape = [3, 8, 7];
    my $mod = mx->mod->Module(

use strict;
use warnings;
use Test::More tests => 10;
use AI::MXNet qw(mx);
use AI::MXNet::Base;

sub test_ctx_group
    my ($data, $fc1, $act1);
        local($mx::AttrScope) = mx->AttrScope(ctx_group=>'stage1');
        $data = mx->symbol->Variable('data');
        $fc1  = mx->symbol->FullyConnected(data => $data, name=>'fc1', num_hidden=>128);

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(almost_equal);
use Test::More tests => 10;

# Checks NDArray reshape() with a -1 placeholder dimension against
# hand-built expected arrays, comparing the values via aspdl->unpdl.
sub test_ndarray_reshape
    # 2x2x2 input holding the values 1..8.
    my $tensor = mx->nd->array([[[1, 2], [3, 4]],
                                [[5, 6], [7, 8]]]);
    # Expected flat result: arange(stop => 8) shifted up by one
    # (presumably 1..8; arange's default start is not visible here).
    my $true_res = mx->nd->arange(stop => 8) + 1;
    # reshape([-1]) should flatten the tensor into a single dimension.
    is_deeply($tensor->reshape([-1])->aspdl->unpdl, $true_res->aspdl->unpdl);
    $true_res  = mx->nd->array([[1, 2, 3, 4],
                                [5, 6, 7, 8]]);
    # reshape([2, -1]) should yield two rows of four values each.
    is_deeply($tensor->reshape([2, -1])->aspdl->unpdl, $true_res->aspdl->unpdl);

                $weight32 += $mom;
        my $tmp = $weight32->astype($weight->dtype);

package main;
use Test::More tests => 1314;
use AI::MXNet::Base;
use PDL::NiceSlice;
use AI::MXNet::TestUtils qw(same reldiff almost_equal);
use AI::MXNet::Function::Parameters;

func compare_optimizer($opt1, $opt2, $shape, $dtype)
    my $w1 = mx->random->uniform({shape => $shape, dtype=>$dtype});
    my $g1 = mx->random->uniform({shape => $shape, dtype=>$dtype});

    my $w2 = $w1->copyto(mx->cpu());
    my $g2 = $g1->copyto(mx->cpu());

use strict;
use warnings;
use Test::More tests => 8;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);

sub check_with_device
    my ($device)     = @_;
    my ($a, $b)      = (-10, 10);
    my ($mu, $sigma) = (10, 2);
    my $shape        = [100, 100];
    my $ret1 = mx->random->normal($mu, $sigma, $shape, { ctx => $device });
    my $un1  = mx->random->uniform($a, $b, $shape, { ctx => $device });

use strict;
use warnings;
use AI::MXNet qw(mx);
use Test::More tests => 1711;
use File::Temp qw/tempfile/;
use PDL;

sub test_recordio
    my ($fd, $frec) = tempfile();
    my $N = 255;

    my $writer = mx->recordio->MXRecordIO($frec, 'w');
    for my $i (0..$N-1)

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);
use PDL;
use Test::More tests => 54;

# Unrolls an RNNCell for three steps and checks the resulting parameter
# names, output names and inferred output shapes.
sub test_rnn
    # The first argument (100) is presumably the hidden size; this matches
    # the [10, 100] output shapes asserted below, but confirm against the
    # RNNCell API.
    my $cell = mx->rnn->RNNCell(100, prefix=>'rnn_');
    # Only the first value returned by unroll() is kept.
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    # The cell should expose exactly the i2h/h2h weight and bias parameters.
    is_deeply([sort keys %{$cell->params->_params}], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    # One output per unrolled step, named after the step index.
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    # Only the second value returned by infer_shape() is inspected; each step
    # is fed a [10, 50] input.
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);

use strict;
use warnings;
use Test::More tests => 98;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(mlp2 conv check_consistency zip assert enumerate);
use Storable qw(freeze thaw);
use PDL;

sub test_symbol_compose
    my $data = mx->symbol->Variable('data');
    my $net1 = mx->symbol->FullyConnected(data=>$data, name=>'fc1', num_hidden=>10);
    $net1 = mx->symbol->FullyConnected(data=>$net1, name=>'fc2', num_hidden=>100);
    is_deeply($net1->list_arguments(), ['data',
                              'fc1_weight', 'fc1_bias',

