AI-MXNet

 view release on metacpan or  search on metacpan

t/test_optimizers.t  view on Meta::CPAN

                $weight .= (1 - $lr*$wd)*$weight - $lr*$self->rescale_grad*$grad;
            }
        }
        else
        {
            my $mom = $state;
            if(defined $self->clip_gradient)
            {
                $mom .= ($self->momentum*$mom - $lr*$wd*$weight -
                    $lr * mx->nd->clip($grad*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
                );
                $weight += $mom;
            }
            else
            {
                $mom .= $self->momentum*$mom - $lr*$wd*$weight - $lr*$self->rescale_grad*$grad;
                $weight += $mom;
            }
        }
    }
    else
    {
        my $grad32 = mx->nd->array($grad, ctx=>$grad->context, dtype=>'float32');
        my $mom = $state->[0];
        my $weight32 = $state->[1];
        if($self->momentum == 0)
        {
            if(defined $self->clip_gradient)
            {
                $weight32 .= ((1 - $lr*$wd)*$weight32 -
                    $lr * mx->nd->clip($grad32*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
                );
            }
            else
            {
                $weight32 .= (1 - $lr*$wd)*$weight32 - $lr*$self->rescale_grad*$grad32;
            }
        }
        else
        {
            if(defined $self->clip_gradient)
            {
                $mom .= ($self->momentum*$mom - $lr*$wd*$weight32 -
                    $lr * mx->nd->clip($grad32*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
                );
                $weight32 += $mom;
            }
            else
            {
                $mom .= $self->momentum*$mom - $lr*$wd*$weight32 - $lr*$self->rescale_grad*$grad32;
                $weight32 += $mom;
            }
        }
        my $tmp = $weight32->astype($weight->dtype);
        $tmp->copyto($weight);
    }
}


package main;
use Test::More tests => 1314;
use AI::MXNet::Base;
use PDL::NiceSlice;
use AI::MXNet::TestUtils qw(same reldiff almost_equal);
use AI::MXNet::Function::Parameters;

# Applies a single update step with two optimizers to identical random
# weights and gradients and asserts that they stay in agreement.
#
# Parameters:
#   $opt1  - first optimizer object (must provide create_state and update)
#   $opt2  - second optimizer object, compared against $opt1
#   $shape - array ref giving the shape of the random weight/gradient arrays
#   $dtype - dtype name used when generating the weight and gradient arrays
#
# Assertions emitted (count and order matter, the file uses a fixed plan):
#   - one per pair of defined state entries right after create_state
#   - one per pair of defined state entries after update
#   - one for the updated weights
func compare_optimizer($opt1, $opt2, $shape, $dtype)
{
    # Every assertion gets a name so that a failure among the many generated
    # cases identifies the optimizer pair, dtype and phase it belongs to.
    my $label = sprintf(
        '%s vs %s [%s]',
        (ref($opt1) || $opt1), (ref($opt2) || $opt2), $dtype
    );

    my $w1 = mx->random->uniform({shape => $shape, dtype=>$dtype});
    my $g1 = mx->random->uniform({shape => $shape, dtype=>$dtype});

    my $w2 = $w1->copyto(mx->cpu());
    my $g2 = $g1->copyto(mx->cpu());

    my $state1 = $opt1->create_state(0, $w1);
    my $state2 = $opt2->create_state(0, $w2);

    # Walks both states in parallel (a lone state is treated as a one-element
    # list) and runs $check on every pair in which both entries are defined.
    my $compare_states = sub {
        my ($check, $phase) = @_;
        return unless defined $state1 and defined $state2;
        zip(
            sub {
                my ($s1, $s2) = @_;
                return unless defined $s1 and defined $s2;
                my $agree = $check->($s1->aspdl, $s2->aspdl);
                ok($agree, "$label: $phase");
            },
            ref $state1 eq 'ARRAY' ? $state1 : [$state1], ref $state2 eq 'ARRAY' ? $state2 : [$state2]
        );
    };

    # Freshly created states must be identical.
    $compare_states->(
        sub { my $identical = same($_[0], $_[1]); $identical },
        'initial states are identical'
    );

    $opt1->update(0, $w1, $g1, $state1);
    $opt2->update(0, $w2, $g2, $state2);

    # After the update only agreement within a relative tolerance is required.
    $compare_states->(
        sub { reldiff($_[0], $_[1]) < 1e-5 },
        'states agree after update'
    );
    ok(reldiff($w1->aspdl, $w2->aspdl) < 1e-5, "$label: weights agree after update");
}

# Compares the PerlAdam optimizer with AI::MXNet::Adam on float32 data for a
# handful of gradient clipping / rescaling settings.
func test_adam()
{
    mx->random->seed(0);
    my ($first_class, $second_class) = ('PerlAdam', 'AI::MXNet::Adam');
    my $shape = [3, 4, 5];
    my @settings = (
        {},
        { clip_gradient => 0.5 },
        { clip_gradient => 0.1 },
        { rescale_grad  => 0.1 },
    );
    foreach my $setting (@settings)
    {
        # Only the second optimizer is given wd => 0.9 explicitly.
        # NOTE(review): presumably PerlAdam's own default matches - confirm
        # against its definition, which is not visible from here.
        my $first  = $first_class->new(%$setting);
        my $second = $second_class->new(wd => 0.9, %$setting);
        compare_optimizer($first, $second, $shape, 'float32');
    }
}

# Compares the PerlRMSProp optimizer with AI::MXNet::RMSProp on float32 data
# across 27 configurations of clip_gradient, rescale_grad, wd, centered and
# clip_weights.
func test_rms()
{
    mx->random->seed(0);
    my ($first_class, $second_class) = ('PerlRMSProp', 'AI::MXNet::RMSProp');
    my $shape = [3, 4, 5];

    # Six base settings that are repeated inside every variant group.
    my @base_settings = (
        { clip_gradient => 0.5 },
        { clip_gradient => 0.4, rescale_grad => 0.14 },
        { rescale_grad  => 0.8 },
        { clip_gradient => 0.5, wd => 0.07 },
        { clip_gradient => 0.4, rescale_grad => 0.14, wd => 0.03 },
        { rescale_grad  => 0.8, wd => 0.05 },
    );

    # Each group is [ include the group's extras on their own?, extras ].
    # The clip_weights-only group has no stand-alone entry.
    my @variant_groups = (
        [ 1, {} ],
        [ 1, { centered => 1 } ],
        [ 0, { clip_weights => 0.01 } ],
        [ 1, { centered => 1, clip_weights => 0.01 } ],
    );

    my @kwargs;
    for my $group (@variant_groups)
    {
        my ($include_bare, $extras) = @$group;
        push @kwargs, { %$extras } if $include_bare;
        push @kwargs, map { +{ %$_, %$extras } } @base_settings;
    }

    foreach my $kwarg (@kwargs)
    {
        my $first  = $first_class->new(%$kwarg);
        my $second = $second_class->new(%$kwarg);
        compare_optimizer($first, $second, $shape, 'float32');
    }
}


# Compares the PerlSGD optimizer with the SGD optimizer obtained from
# mx->optimizer over every combination of momentum, clip_gradient,
# rescale_grad, wd and multi_precision settings, for three dtypes.
# float16 is only exercised when multi_precision is switched on.
sub test_sgd
{
    mx->random->seed(0);
    my $opt1 = 'PerlSGD';
    my $opt2 = mx->optimizer->SGD;
    my $shape = [3, 4, 5];

    # One entry per option axis; the first axis varies slowest.
    my @option_axes = (
        [ {}, { momentum => 0.9 } ],
        [ {}, { clip_gradient => 0.4 }, { clip_gradient => 0.5 } ],
        [ {}, { rescale_grad => 0.14 }, { rescale_grad => 0.8 } ],
        [ {}, { wd => 0.03 }, { wd => 0.05 }, { wd => 0.07 } ],
        [ {}, { multi_precision => 0 }, { multi_precision => 1 } ],
    );

    # Cartesian product of all axes, each combination merged into one hash.
    my @combinations = ({});
    for my $axis (@option_axes)
    {
        @combinations = map {
            my $partial = $_;
            map { +{ %$partial, %$_ } } @$axis;
        } @combinations;
    }

    for my $dtype (qw/float16 float32 float64/)
    {
        for my $combination (@combinations)
        {
            my %kwarg = %$combination;
            # A missing multi_precision key reads as false, same as 0.
            next if $dtype eq 'float16' and not $kwarg{multi_precision};
            compare_optimizer($opt1->new(%kwarg), $opt2->new(%kwarg), $shape, $dtype);
        }
    }
}

# Builds a two-layer fully connected network with lr_mult / wd_mult
# attributes, runs one forward/backward/update cycle and checks both the
# multiplier tables held by the optimizer and which parameters changed.
func test_lr_wd_mult()
{
    my $data = mx->sym->Variable('data');
    my $bias = mx->sym->Variable('fc1_bias', lr_mult => 1.0);
    my $fc1  = mx->sym->FullyConnected({
        data       => $data,
        bias       => $bias,
        name       => 'fc1',
        num_hidden => 10,
        lr_mult    => 0
    });
    my $fc2  = mx->sym->FullyConnected({
        data       => $fc1,
        name       => 'fc2',
        num_hidden => 10,
        wd_mult    => 0.5
    });

    my $mod = mx->mod->new(symbol => $fc2, label_names => undef);
    $mod->bind(data_shapes => [['data', [5,10]]]);
    $mod->init_params(initializer => mx->init->Uniform(scale => 1.0));
    $mod->init_optimizer(optimizer_params => { learning_rate => "1.0" });

    # Returns a hash ref mapping each argument name to the aspdl form of the
    # module's current value for it.
    my $snapshot_params = sub {
        my ($arg_params) = $mod->get_params();
        my %as_pdl;
        for my $name (keys %$arg_params)
        {
            $as_pdl{$name} = $arg_params->{$name}->aspdl;
        }
        return \%as_pdl;
    };

    my $before = $snapshot_params->();

    my $batch = AI::MXNet::DataBatch->new(
        data  => [mx->random->uniform({low=>-1.0, high=>1.0, shape=>[5,10]})],
        label => undef
    );
    $mod->forward($batch, is_train=>1);
    $mod->backward($mod->get_outputs());
    $mod->update();

    my $after = $snapshot_params->();

    my $optimizer = $mod->_p->_optimizer;
    is_deeply($optimizer->lr_mult, { fc1_bias => 1, fc1_weight => 0 }, "lr_mult");
    is_deeply($optimizer->wd_mult, { fc2_bias => 0.5, fc2_weight => 0.5, fc1_bias => 0, }, "wd_mult");

    # fc1_weight is expected to be unchanged by the update, while fc1_bias
    # and fc2_weight are expected to differ from their earlier values.
    ok(almost_equal($before->{fc1_weight}, $after->{fc1_weight}, 1e-10), "fc1_weight");
    ok(!almost_equal($before->{fc1_bias}, $after->{fc1_bias}, 1e-1), "fc1_bias");
    ok(!almost_equal($before->{fc2_weight}, $after->{fc2_weight}, 1e-1), "fc2_weight");
}

# Run every test group in a fixed order; together they must emit exactly the
# number of assertions declared in the Test::More plan.
for my $test_group (\&test_adam, \&test_rms, \&test_sgd, \&test_lr_wd_mult)
{
    $test_group->();
}



( run in 0.337 second using v1.01-cache-2.11-cpan-39bf76dae61 )