AI-MXNet
view release on metacpan or search on metacpan
t/test_optimizers.t view on Meta::CPAN
$weight .= (1 - $lr*$wd)*$weight - $lr*$self->rescale_grad*$grad;
}
}
else
{
my $mom = $state;
if(defined $self->clip_gradient)
{
$mom .= ($self->momentum*$mom - $lr*$wd*$weight -
$lr * mx->nd->clip($grad*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
);
$weight += $mom;
}
else
{
$mom .= $self->momentum*$mom - $lr*$wd*$weight - $lr*$self->rescale_grad*$grad;
$weight += $mom;
}
}
}
else
{
my $grad32 = mx->nd->array($grad, ctx=>$grad->context, dtype=>'float32');
my $mom = $state->[0];
my $weight32 = $state->[1];
if($self->momentum == 0)
{
if(defined $self->clip_gradient)
{
$weight32 .= ((1 - $lr*$wd)*$weight32 -
$lr * mx->nd->clip($grad32*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
);
}
else
{
$weight32 .= (1 - $lr*$wd)*$weight32 - $lr*$self->rescale_grad*$grad32;
}
}
else
{
if(defined $self->clip_gradient)
{
$mom .= ($self->momentum*$mom - $lr*$wd*$weight32 -
$lr * mx->nd->clip($grad32*$self->rescale_grad, -$self->clip_gradient, $self->clip_gradient)
);
$weight32 += $mom;
}
else
{
$mom .= $self->momentum*$mom - $lr*$wd*$weight32 - $lr*$self->rescale_grad*$grad32;
$weight32 += $mom;
}
}
my $tmp = $weight32->astype($weight->dtype);
$tmp->copyto($weight);
}
}
package main;
use Test::More tests => 1314;
use AI::MXNet::Base;
use PDL::NiceSlice;
use AI::MXNet::TestUtils qw(same reldiff almost_equal);
use AI::MXNet::Function::Parameters;
# Seed identical weight/gradient pairs on two devices, run one update step
# with each optimizer, and verify the resulting states and weights agree.
func compare_optimizer($opt1, $opt2, $shape, $dtype)
{
    my $weight1 = mx->random->uniform({shape => $shape, dtype=>$dtype});
    my $grad1   = mx->random->uniform({shape => $shape, dtype=>$dtype});
    my $weight2 = $weight1->copyto(mx->cpu());
    my $grad2   = $grad1->copyto(mx->cpu());
    my $state1 = $opt1->create_state(0, $weight1);
    my $state2 = $opt2->create_state(0, $weight2);
    # Apply $check to each paired state entry; a state may be a single
    # NDArray or an array ref of them, so normalize both to array refs.
    my $each_state_pair = sub {
        my ($check) = @_;
        return unless defined $state1 and defined $state2;
        zip(
            $check,
            ref $state1 eq 'ARRAY' ? $state1 : [$state1],
            ref $state2 eq 'ARRAY' ? $state2 : [$state2]
        );
    };
    # Freshly created states must match exactly.
    $each_state_pair->(sub {
        my ($s1, $s2) = @_;
        ok(same($s1->aspdl, $s2->aspdl)) if defined $s1 and defined $s2;
    });
    $opt1->update(0, $weight1, $grad1, $state1);
    $opt2->update(0, $weight2, $grad2, $state2);
    # After one step, states and weights agree within relative tolerance.
    $each_state_pair->(sub {
        my ($s1, $s2) = @_;
        ok(reldiff($s1->aspdl, $s2->aspdl) < 1e-5) if defined $s1 and defined $s2;
    });
    ok(reldiff($weight1->aspdl, $weight2->aspdl) < 1e-5);
}
# Compare the pure-Perl Adam reference implementation against
# AI::MXNet::Adam across several clipping/rescaling configurations.
func test_adam()
{
    mx->random->seed(0);
    my $reference = 'PerlAdam';
    my $candidate = 'AI::MXNet::Adam';
    my $shape     = [3, 4, 5];
    my @cases = (
        {},
        {'clip_gradient'=> 0.5},
        {'clip_gradient'=> 0.1},
        {'rescale_grad'=> 0.1},
    );
    # NOTE(review): wd => 0.9 is supplied only to the AI::MXNet::Adam side —
    # presumably PerlAdam applies the same decay by default; confirm.
    for my $case (@cases)
    {
        compare_optimizer($reference->new(%$case), $candidate->new(wd => 0.9, %$case), $shape, 'float32');
    }
}
func test_rms()
{
mx->random->seed(0);
my $opt1 = 'PerlRMSProp';
my $opt2 = 'AI::MXNet::RMSProp';
my $shape = [3, 4, 5];
my @kwargs = ({},
{clip_gradient => 0.5},
{clip_gradient => 0.4, rescale_grad => 0.14},
{rescale_grad => 0.8},
{clip_gradient => 0.5, wd => 0.07},
{clip_gradient => 0.4, rescale_grad => 0.14, wd => 0.03},
( run in 0.866 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )