AI-MXNet

lib/AI/MXNet/Optimizer.pm

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
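    # lr and wd are resolved per parameter index (they may be scaled by
    # per-parameter multipliers), and _update_count advances the update
    # counter used by an optional learning-rate scheduler.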
    my $kwargs = {
        out => $weight,
        lr  => $lr,
        wd  => $wd,
        %{ $self->kwargs }
    };
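    # An array-ref state signals multi-precision mode: $state->[0] holds the
    # optional momentum buffer and $state->[1] a float32 master copy of the
    # (typically float16) weight.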
    my $use_multi_precision = ref($state) eq 'ARRAY';
    if(not $use_multi_precision)
    {
        if(defined $state)
        {
            AI::MXNet::NDArray->sgd_mom_update(
                $weight, $grad, $state, $kwargs
            );
        }
        else
        {
            AI::MXNet::NDArray->sgd_update(
                $weight, $grad, $kwargs
            );
        }
    }
    else
    {
        if(defined $state->[0])
        {
            AI::MXNet::NDArray->mp_sgd_mom_update(
                $weight, $grad, $state->[0], $state->[1], $kwargs
            );
        }
        else
        {
            AI::MXNet::NDArray->mp_sgd_update(
                $weight, $grad, $state->[1], $kwargs
            );
        }
    }
}

__PACKAGE__->register;
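For orientation, here is a minimal, hypothetical driver for the update path above, assuming this excerpt is the plain SGD optimizer (registered as AI::MXNet::SGD; the class name is not visible in this excerpt, and shapes and values are illustrative only):

    use AI::MXNet qw(mx);

    my $opt    = AI::MXNet::SGD->new(learning_rate => 0.1, momentum => 0.9);
    my $weight = mx->nd->ones([2, 2]);
    my $grad   = mx->nd->ones([2, 2]) * 0.5;
    my $state  = $opt->create_state(0, $weight);  # momentum buffer, or undef
    $opt->update(0, $weight, $grad, $state);      # one in-place step on $weight

In real training this loop is normally driven by AI::MXNet::Module rather than called by hand.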

package AI::MXNet::DCASGD;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Optimizer';

=head1 NAME

    AI::MXNet::DCASGD - DCASGD optimizer with momentum and weight regularization.

=cut

=head1 DESCRIPTION

    DCASGD optimizer with momentum and weight regularization.

    Implements the DC-ASGD algorithm from the paper "Asynchronous Stochastic
    Gradient Descent with Delay Compensation for Distributed Deep Learning".

    Parameters
    ----------
    learning_rate : float, optional
        learning rate of the update, as in SGD

    momentum : float, optional
       momentum coefficient

    lamda : float, optional
       scaling factor of the delay-compensation (DC) term; note that 'lamda'
       is the literal spelling of the attribute name

    wd : float, optional
        L2 regularization coefficient applied to all the weights

    rescale_grad : float, optional
        rescaling factor of the gradient. Normally should be 1/batch_size.

    clip_gradient : float, optional
        clip the gradient to the range [-clip_gradient, clip_gradient]

    param_idx2name : hash ref of int to string, optional
        maps parameter indices to parameter names; weight decay receives
        special treatment for parameters whose names end with 'bias',
        'gamma', or 'beta'

=cut
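# The delay compensation approximates the gradient at the current weight w
# from a stale gradient g that was computed at a backed-up copy w_bak:
#
#     g_dc = g + lamda * g * g * (w - w_bak)    (all products elementwise)
#
# create_state below stores that backup copy next to the optional momentum
# buffer; update applies g_dc in place of g.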
has 'momentum'        => (is => 'ro', isa => 'Num', default => 0);
has 'lamda'           => (is => 'ro', isa => 'Num', default => 0.04);
has 'weight_previous' => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->weight_previous({});
}

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    # The state is a pair: an optional momentum buffer, plus a backup copy of
    # the weight that the delay-compensation term compares against.
    return [
        $self->momentum ? AI::MXNet::NDArray->zeros(
            $weight->shape, ctx => $weight->context, dtype => $weight->dtype
        ) : undef,
        $weight->copy
    ];
}

method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    ArrayRef[Maybe[AI::MXNet::NDArray]] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    $grad *= $self->rescale_grad;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
            $self->clip_gradient
        );
    }
    # Delay-compensated step (per the paper cited above): correct the stale
    # gradient with lamda * grad * grad * (weight - backup), then apply a
    # standard, optionally momentum-based, SGD update.
    my ($mom, $previous_weight) = @{ $state };
    my $dc_grad = $grad + $wd * $weight
        + $self->lamda * $grad * $grad * ($weight - $previous_weight);
    if(defined $mom)
    {
        $mom *= $self->momentum;
        $mom -= $lr * $dc_grad;
    }
    else
    {
        $mom = -$lr * $dc_grad;
    }
    $previous_weight .= $weight;   # refresh the backup to the new base weight
    $weight += $mom;
}

__PACKAGE__->register;


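A similar hypothetical sketch for the class above (values illustrative; create_state supplies the [momentum, weight backup] pair that update consumes):

    use AI::MXNet qw(mx);

    my $opt    = AI::MXNet::DCASGD->new(
        learning_rate => 0.1,
        momentum      => 0.9,
        lamda         => 0.04,
    );
    my $weight = mx->nd->ones([4]);
    my $grad   = mx->nd->ones([4]) * 0.2;
    my $state  = $opt->create_state(0, $weight);  # [momentum buffer, weight backup]
    $opt->update(0, $weight, $grad, $state);      # one delay-compensated step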