AI-MXNet
lib/AI/MXNet/Optimizer.pm
package AI::MXNet::Optimizer;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::NDArray;
use AI::MXNet::Random;
use List::Util qw(max);
=head1 NAME
AI::MXNet::Optimizer - Common optimization algorithms with regularization.
=head1 DESCRIPTION
Common optimization algorithms with regularization.
=cut
use Mouse;
use AI::MXNet::Function::Parameters;
my %opt_registry;
method get_opt_registry()
{
return \%opt_registry;
}
method register()
{
my $name = $self;
($name) = $name =~ /::(\w+)$/;
{ no strict 'refs'; *{__PACKAGE__."::$name"} = sub { $self }; }
$name = lc $name;
if(exists $opt_registry{ $name })
{
my $existing = $opt_registry{ $name };
warn(
"WARNING: New optimizer $self.$name"
."is overriding existing optimizer $existing.$name"
);
}
$opt_registry{ $name } = $self;
}
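# A minimal registration sketch. AI::MXNet::MyOpt is a hypothetical subclass,
# not part of this distribution; calling register() on it stores the class in
# %opt_registry under the lower-cased last component of its package name.
#
#   package AI::MXNet::MyOpt;
#   use Mouse;
#   extends 'AI::MXNet::Optimizer';
#   __PACKAGE__->register;
#
#   # later: AI::MXNet::Optimizer->get_opt_registry() contains { myopt => 'AI::MXNet::MyOpt', ... }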
=head2 create_optimizer
Create an optimizer with specified name.
Parameters
----------
name: str
Name of required optimizer. Should be the name
of a subclass of Optimizer. Case insensitive.
rescale_grad : float
Rescaling factor on gradient. Normally should be 1/batch_size.
kwargs: dict
Parameters for optimizer
Returns
-------
opt : Optimizer
The resulting optimizer.
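A minimal usage sketch (assumes the SGD subclass bundled with AI::MXNet has
been loaded and registered under the name 'sgd'; the argument values are
illustrative):

    my $opt = AI::MXNet::Optimizer->create_optimizer(
        'sgd',
        learning_rate => 0.1,
        rescale_grad  => 1/128,
        wd            => 0.0001
    );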
=cut
method create_optimizer(Str $name, %kwargs)
{
if(exists $opt_registry{ lc $name })
{
my $rescale_grad = delete($kwargs{rescale_grad})//1;
return $opt_registry{ lc $name }->new(
rescale_grad => $rescale_grad,
%kwargs
);
}
confess("Cannot find optimizer $name");
}
*create = \&create_optimizer;
has 'rescale_grad' => (is => "rw", isa => "Num", default=>1);
has 'lr' => (is => "rw", isa => "Num");
has 'learning_rate' => (is => "rw", isa => "Num", default => 0.01);
has 'lr_scheduler' => (is => "rw", isa => "Maybe[AI::MXNet::LRScheduler]");
has 'wd' => (is => "rw", isa => "Num", default => 0);
has 'lr_mult' => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'wd_mult' => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'num_update' => (is => "rw", isa => "Int");
has 'begin_num_update' => (is => "rw", isa => "Int", default => 0);
has '_index_update_count' => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'clip_gradient' => (is => "rw", isa => "Maybe[Num]");
has 'param_idx2name' => (is => "rw", isa => "HashRef[Str]", default => sub { +{} });
has 'idx2name' => (is => "rw", isa => "HashRef[Str]");
has 'sym' => (is => "rw", isa => "Maybe[AI::MXNet::Symbol]");
sub BUILD
{
my $self = shift;
if($self->lr_scheduler)
{
$self->lr_scheduler->base_lr($self->learning_rate);
}
$self->lr($self->learning_rate);
$self->num_update($self->begin_num_update);
$self->idx2name({ %{ $self->param_idx2name } });
$self->set_lr_mult({});
$self->set_wd_mult({});
}
# Create additional optimizer state such as momentum.
# override in implementations.
method create_state($index, $weight){}
# Update the parameters. override in implementations
method update($index, $weight, $grad, $state){}
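# A sketch of what a concrete optimizer overrides. AI::MXNet::MyOpt is
# hypothetical and the plain-SGD rule below is only illustrative; a real
# implementation would also honor clip_gradient and the lr/wd multipliers.
#
#   package AI::MXNet::MyOpt;
#   use Mouse;
#   use AI::MXNet::Function::Parameters;
#   extends 'AI::MXNet::Optimizer';
#   # plain SGD keeps no per-parameter state
#   method create_state($index, $weight) { return undef }
#   method update($index, $weight, $grad, $state)
#   {
#       my $lr = $self->learning_rate;
#       my $wd = $self->wd;
#       # in-place NDArray update: weight -= lr * (rescale_grad * grad + wd * weight)
#       $weight .= $weight - $lr * ($self->rescale_grad * $grad + $wd * $weight);
#   }
#   __PACKAGE__->register;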
# set lr scale is deprecated. Use set_lr_mult instead.
method set_lr_scale($args_lrscale)
{
Carp::cluck("set lr scale is deprecated. Use set_lr_mult instead.");
}
=head2 set_lr_mult
Set individual learning rate multiplier for parameters.
Parameters
----------
args_lr_mult : hash ref of str/int to float
Set the lr multiplier for a given name/index to a float value.
Setting the multiplier by index is supported for backward compatibility,
but we recommend using name and symbol.
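A short usage sketch (the parameter name 'fc1_weight' is illustrative):

    $opt->set_lr_mult({ 'fc1_weight' => 0.1 });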
=cut
method set_lr_mult(HashRef[Num] $args_lr_mult)
{
$self->lr_mult({});
if($self->sym)
{
my $attr = $self->sym->attr_dict();
for my $name (@{ $self->sym->list_arguments() })
{
if(exists $attr->{ $name } and exists $attr->{ $name }{ __lr_mult__ })
{
$self->lr_mult->{ $name } = $attr->{ $name }{ __lr_mult__ };
}
}
}
$self->lr_mult({ %{ $self->lr_mult }, %{ $args_lr_mult } });
}
=head2 set_wd_mult
Set individual weight decay multiplier for parameters.
By default the wd multiplier is 0 for all params whose name doesn't
end with _weight or _gamma, if param_idx2name is provided.
Parameters
----------
args_wd_mult : hash ref of str/int to float
Set the wd multiplier for a given name/index to a float value.
Setting the multiplier by index is supported for backward compatibility,
but we recommend using name and symbol.
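A short usage sketch (the parameter name 'fc1_bias' is illustrative):

    $opt->set_wd_mult({ 'fc1_bias' => 0 });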
=cut
method set_wd_mult(HashRef[Num] $args_wd_mult)
{
$self->wd_mult({});
for my $n (values %{ $self->idx2name })
{
if(not $n =~ /(?:_weight|_gamma)$/)
{
$self->wd_mult->{ $n } = 0;
}
}
if($self->sym)
{
my $attr = $self->sym->attr_dict();
for my $name (@{ $self->sym->list_arguments() })
{
if(exists $attr->{ $name } and exists $attr->{ $name }{ __wd_mult__ })
{
$self->wd_mult->{ $name } = $attr->{ $name }{ __wd_mult__ };
}
}
}
$self->wd_mult({ %{ $self->wd_mult }, %{ $args_wd_mult } });
}