AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Optimizer.pm  view on Meta::CPAN

package AI::MXNet::Optimizer;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::NDArray;
use AI::MXNet::Random;
use List::Util qw(max);

=head1 NAME

    AI::MXNet::Optimizer - Common optimization algorithms with regularization.

=head1 DESCRIPTION

    Common optimization algorithms with regularization.

=cut

use Mouse;
use AI::MXNet::Function::Parameters;
# Registry of optimizer classes, keyed by the lower-cased short class name.
# Filled in by register(); looked up by create_optimizer().
my %opt_registry = ();
method get_opt_registry()
{
    return \%opt_registry;
}

# Registers the invoking optimizer class in the optimizer registry.
#
# $self is expected to be a class name (e.g. "AI::MXNet::SGD").
# The short name is taken from the last package component.
#
# Side effects:
#   * installs AI::MXNet::Optimizer::<ShortName>, a sub returning the class,
#     so AI::MXNet::Optimizer->SGD yields the registered class;
#   * stores the class under lc(<ShortName>) in %opt_registry, warning if
#     that key was already taken.
#
# Returns the registered class (value of the final assignment).
# Dies (confess) if no short name can be derived from $self.
method register()
{
    # /(\w+)\z/ picks the trailing package component. Unlike /::(\w+)$/,
    # it also works for a package whose name has no '::'; that case
    # previously produced undef and a broken registry entry.
    my ($name) = $self =~ /(\w+)\z/;
    confess("Cannot derive an optimizer name from class '$self'")
        unless defined $name;

    {
        no strict 'refs';
        *{ __PACKAGE__ . "::$name" } = sub { $self };
    }

    $name = lc $name;
    if (exists $opt_registry{ $name })
    {
        my $existing = $opt_registry{ $name };
        # Keep a space between the two pieces so the message reads
        # "... <old>.<name> is overriding ..." rather than "...<name>is ...".
        warn(
            "WARNING: New optimizer $self.$name "
            ."is overriding existing optimizer $existing.$name"
        );
    }
    $opt_registry{ $name } = $self;
}

=head2 create_optimizer

        Create an optimizer by name.

        Parameters
        ----------
        name: Str
            Name of the required optimizer. Should be the name
            of a registered subclass of Optimizer. Case-insensitive.

        rescale_grad : Num, optional (default 1)
            Rescaling factor applied to the gradient. Normally 1/batch_size.

        kwargs: hash
            Additional constructor arguments for the optimizer.

        Returns
        -------
        opt : Optimizer
            The resulting optimizer. Dies (via confess) if no optimizer
            is registered under the given name.

=cut

method create_optimizer(Str $name, %kwargs)
{
    # Registry keys are lower-cased, so the lookup is case-insensitive.
    my $key = lc $name;
    confess("Cannot find optimizer $name")
        unless exists $opt_registry{ $key };

    # rescale_grad is always passed explicitly. It falls back to 1 when it
    # is absent or undef. All other arguments are forwarded unchanged.
    my $rescale_grad = delete $kwargs{rescale_grad};
    $rescale_grad //= 1;

    my $optimizer_class = $opt_registry{ $key };
    return $optimizer_class->new(rescale_grad => $rescale_grad, %kwargs);
}

# Short alias: create(...) is the same sub as create_optimizer(...).
*create = \&create_optimizer;

# Gradient rescaling factor; normally 1/batch_size.
has 'rescale_grad'        => (is => "rw", isa => "Num", default => 1);
# Current learning rate; BUILD initializes it from learning_rate.
has 'lr'                  => (is => "rw", isa => "Num");
# Initial learning rate supplied at construction.
has 'learning_rate'       => (is => "rw", isa => "Num", default => 0.01);
# Optional learning-rate scheduler; BUILD sets its base_lr to learning_rate.
has 'lr_scheduler'        => (is => "rw", isa => "Maybe[AI::MXNet::LRScheduler]");
# Weight decay (presumably the L2 regularization coefficient - confirm).
has 'wd'                  => (is => "rw", isa => "Num", default => 0);
# Learning-rate / weight-decay multipliers. The keys are presumably parameter
# names or indices (confirm). BUILD calls set_lr_mult({}) and set_wd_mult({}).
has 'lr_mult'             => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'wd_mult'             => (is => "rw", isa => "HashRef", default => sub { +{} });
# Update counter; BUILD initializes it from begin_num_update.
has 'num_update'          => (is => "rw", isa => "Int");
has 'begin_num_update'    => (is => "rw", isa => "Int", default => 0);
# NOTE(review): presumably per-index update counts - confirm with callers.
has '_index_update_count' => (is => "rw", isa => "HashRef", default => sub { +{} });
# Optional gradient clipping bound (undef presumably disables it - confirm).
has 'clip_gradient'       => (is => "rw", isa => "Maybe[Num]");
# Caller-supplied index => parameter-name map; BUILD copies it into idx2name.
has 'param_idx2name'      => (is => "rw", isa => "HashRef[Str]", default => sub { +{} });
# Working copy of param_idx2name, created by BUILD.
has 'idx2name'            => (is => "rw", isa => "HashRef[Str]");
# Optional symbol; how it is used is not visible in this section.
has 'sym'                 => (is => "rw", isa => "Maybe[AI::MXNet::Symbol]");

# Post-construction hook (called by Mouse after the attributes are set).
# It derives the working state from the constructor arguments:
#   * seeds the optional scheduler's base_lr with learning_rate;
#   * sets lr from learning_rate and num_update from begin_num_update;
#   * makes idx2name a shallow copy of param_idx2name;
#   * resets the lr/wd multipliers via set_lr_mult({}) / set_wd_mult({}).
sub BUILD
{
    my ($self) = @_;

    my $scheduler = $self->lr_scheduler;
    $scheduler->base_lr($self->learning_rate) if $scheduler;

    $self->lr($self->learning_rate);
    $self->num_update($self->begin_num_update);

    # A copy, so later edits to idx2name leave the caller's hash untouched.
    my %name_of_index = %{ $self->param_idx2name };
    $self->idx2name(\%name_of_index);

    $self->set_lr_mult({});
    $self->set_wd_mult({});
}
# Create additional optimizer state such as momentum.
# Override in implementations: the base version has an empty body and
# creates no state.
#   $index  - presumably the parameter index (cf. idx2name) - confirm
#   $weight - presumably the parameter's AI::MXNet::NDArray - confirm
# NOTE(review): there is no explicit return here. Confirm what callers
# expect from the base version before adding an explicit `return;`.
method create_state($index, $weight){}



( run in 0.776 second using v1.01-cache-2.11-cpan-39bf76dae61 )