AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Optimizer.pm view on Meta::CPAN
)
{
my $lr = $self->_get_lr($index);
my $wd = $self->_get_wd($index);
$self->_update_count($index);
my $kwargs = {
out => $weight,
lr => $lr,
wd => $wd,
%{ $self->kwargs }
};
my $use_multi_precision = ref($state) eq 'ARRAY';
if(not $use_multi_precision)
{
if(defined $state)
{
AI::MXNet::NDArray->sgd_mom_update(
$weight, $grad, $state, $kwargs
);
}
else
{
AI::MXNet::NDArray->sgd_update(
$weight, $grad, $kwargs
);
}
}
else
{
if(defined $state->[0])
{
AI::MXNet::NDArray->mp_sgd_mom_update(
$weight, $grad, $state->[0], $state->[1], $kwargs
);
}
else
{
AI::MXNet::NDArray->mp_sgd_update(
$weight, $grad, $state->[1], $kwargs
);
}
}
}
# NOTE(review): presumably registers this optimizer class with the
# AI::MXNet::Optimizer name-based factory — confirm against the base class.
__PACKAGE__->register;
package AI::MXNet::DCASGD;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Optimizer';
=head1 NAME
AI::MXNet::DCASGD - DCASGD optimizer with momentum and weight regularization.
=cut
=head1 DESCRIPTION
DCASGD optimizer with momentum and weight regularization.
Implements paper "Asynchronous Stochastic Gradient Descent with
Delay Compensation for Distributed Deep Learning"
Parameters
----------
learning_rate : float, optional
learning_rate of SGD
momentum : float, optional
momentum value
lamda : float, optional
scale DC value
wd : float, optional
L2 regularization coefficient added to all the weights
rescale_grad : float, optional
rescaling factor of gradient. Normally should be 1/batch_size.
clip_gradient : float, optional
clip gradient in range [-clip_gradient, clip_gradient]
param_idx2name : hash ref of string/int to float, optional
special treatment of weight decay for parameters whose names end with bias, gamma, or beta
=cut
# Momentum coefficient; when 0 (the default), create_state allocates no
# momentum buffer (the first state slot stays undef).
has 'momentum' => (is => 'ro', isa => 'Num', default => 0);
# Delay-compensation scaling factor. The "lamda" spelling is intentional —
# it is the attribute's public name; do not rename without updating callers.
has 'lamda' => (is => 'ro', isa => 'Num', default => 0.04);
# Internal cache of previous weights; not settable from the constructor
# (init_arg => undef), initialized to an empty hash ref in BUILD.
has 'weight_previous' => (is => 'rw', init_arg => undef);
# Post-construction hook: start with an empty previous-weight cache.
sub BUILD
{
    my ($self) = @_;
    $self->weight_previous(+{});
}
# Allocate per-parameter optimizer state: an arrayref holding
#   [0] a zero-filled momentum buffer matching the weight's shape/ctx/dtype,
#       or undef when momentum is disabled (momentum == 0), and
#   [1] a copy of the current weight (the "previous weight" snapshot used
#       by DCASGD's delay compensation).
method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    my $momentum_buf;
    if($self->momentum)
    {
        $momentum_buf = AI::MXNet::NDArray->zeros(
            $weight->shape, ctx => $weight->context, dtype => $weight->dtype
        );
    }
    return [$momentum_buf, $weight->copy];
}
method update(
Index $index,
AI::MXNet::NDArray $weight,
AI::MXNet::NDArray $grad,
Maybe[AI::MXNet::NDArray] $state
)
{
my $lr = $self->_get_lr($index);
my $wd = $self->_get_wd($index);
$self->_update_count($index);
$grad *= $self->rescale_grad;
if($self->clip_gradient)
{
$grad = AI::MXNet::NDArray->clip(
$grad,
( run in 1.376 second using v1.01-cache-2.11-cpan-39bf76dae61 )