AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Optimizer.pm view on Meta::CPAN
$self->clip_gradient
);
}
# warming momentum schedule
my $momentum_t = $self->beta1 * (1 - 0.5 * (0.96**($t * $self->schedule_decay)));
my $momentum_t_1 = $self->beta1 * (1 - 0.5 * (0.96**(($t + 1) * $self->schedule_decay)));
$self->m_schedule = $self->m_schedule * $momentum_t;
my $m_schedule_next = $self->m_schedule * $momentum_t_1;
# update m_t and v_t
my ($m_t, $v_t) = @{ $state };
$m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
$v_t .= $self->beta2 * $v_t + (1 - $self->beta2) * $grad * $grad;
my $grad_prime = $grad / (1 - $self->m_schedule);
my $m_t_prime = $m_t / (1 - $m_schedule_next);
my $v_t_prime = $v_t / (1 - $self->beta2**$t);
my $m_t_bar = (1 - $momentum_t) * $grad_prime + $momentum_t_1 * $m_t_prime;
# update weight
$weight -= $lr * $m_t_bar / (sqrt($v_t_prime) + $self->epsilon);
}
# Call the current package's register class method once its methods are defined.
# NOTE(review): presumably this adds the class to the optimizer registry so it
# can be looked up by name -- confirm against the definition of register().
__PACKAGE__->register;
# updater for kvstore
package AI::MXNet::Updater;
use Mouse;
use Storable qw(thaw freeze);

# An updater can be invoked like a code reference: $updater->(...) forwards
# its arguments to $updater->call(...).
use overload
    "&{}"    => sub { my $updater = shift; return sub { $updater->call(@_) } },
    fallback => 1;

# optimizer     - the AI::MXNet::Optimizer that creates states and applies updates.
# states        - per-index optimizer state, as returned by create_state().
# states_synced - per-index flag: 1 once the state was created or moved by
#                 call(), 0 for every state loaded through set_states().
has "optimizer"     => (is => "rw", isa => "AI::MXNet::Optimizer");
has "states"        => (is => "rw", isa => "HashRef", default => sub { +{} });
has "states_synced" => (is => "rw", isa => "HashRef", default => sub { +{} });

# Apply one optimizer update for the parameter identified by $index.
#
# Parameters:
#   $index  - key under which the optimizer state is stored
#   $grad   - gradient NDArray
#   $weight - weight NDArray passed on to the optimizer
#
# A missing state is created via the optimizer; a state that is present but
# not flagged as synced is first passed through sync_state_context() with the
# weight's context. Returns whatever the optimizer's update() returns.
method call(Index $index, AI::MXNet::NDArray $grad, AI::MXNet::NDArray $weight)
{
    my $state_of  = $self->states;
    my $synced    = $self->states_synced;
    my $is_new    = !exists $state_of->{ $index };

    if($is_new or not $synced->{ $index })
    {
        $state_of->{ $index } = $is_new
            ? $self->optimizer->create_state($index, $weight)
            : $self->sync_state_context($state_of->{ $index }, $weight->context);
        $synced->{ $index } = 1;
    }

    $self->optimizer->update($index, $weight, $grad, $state_of->{ $index });
}

# slice is a second name for call (the whole glob is aliased).
*slice = *call;

# Return $state relocated to $context.
#
# Parameters:
#   $state   - an NDArray, an array reference of NDArrays, or undef
#   $context - the target AI::MXNet::Context
#
# A blessed state is converted with as_in_context(); an unblessed reference is
# treated as an array reference and rebuilt element by element (recursively);
# anything else is handed back unchanged.
method sync_state_context(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $state, AI::MXNet::Context $context)
{
    return $state->as_in_context($context) if blessed $state;
    return $state unless ref $state;

    my @relocated = map { $self->sync_state_context($_, $context) } @{ $state };
    return \@relocated;
}

# Replace all optimizer states with the ones stored in $states.
#
# Parameters:
#   $states - a Storable image as produced by get_states()
#
# Every loaded index is flagged as not synced, so the next call() for it runs
# sync_state_context() before updating.
# NOTE(review): Storable's thaw must only be given trusted data; verify where
# $states comes from in the callers.
method set_states($states)
{
    my $restored = thaw($states);
    $self->states($restored);

    my %unsynced;
    @unsynced{ keys %{ $restored } } = (0) x scalar(keys %{ $restored });
    %{ $self->states_synced } = %unsynced;
}

# Return the current states hash serialised with Storable's freeze.
method get_states()
{
    return freeze($self->states);
}
package AI::MXNet::Optimizer;
# Build an AI::MXNet::Updater that drives the given optimizer.
#
# Parameters:
#   $optimizer - the AI::MXNet::Optimizer the new updater will use
#
# Returns the newly constructed AI::MXNet::Updater object.
method get_updater(AI::MXNet::Optimizer $optimizer)
{
    my %updater_args = (optimizer => $optimizer);
    return AI::MXNet::Updater->new(%updater_args);
}
1;
( run in 0.653 second using v1.01-cache-2.11-cpan-39bf76dae61 )