AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Optimizer.pm  view on Meta::CPAN

             $self->clip_gradient
        );
    }
    # warming momentum schedule
    my $momentum_t    = $self->beta1 * (1 - 0.5 * (0.96**($t * $self->schedule_decay)));
    my $momentum_t_1  = $self->beta1 * (1 - 0.5 * (0.96**(($t + 1) * $self->schedule_decay)));
    $self->m_schedule = $self->m_schedule * $momentum_t;
    my $m_schedule_next  = $self->m_schedule * $momentum_t_1;

    # update m_t and v_t
    my ($m_t, $v_t) = @{ $state };
    $m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
    $v_t .= $self->beta2 * $v_t + (1 - $self->beta2) * $grad * $grad;

    my $grad_prime = $grad / (1 - $self->m_schedule);
    my $m_t_prime  = $m_t  / (1 - $m_schedule_next);
    my $v_t_prime  = $v_t  / (1 - $self->beta2**$t);
    my $m_t_bar    = (1 - $momentum_t) * $grad_prime + $momentum_t_1 * $m_t_prime;

    # update weight
    $weight -= $lr * $m_t_bar / (sqrt($v_t_prime) + $self->epsilon);
}

# NOTE(review): register() is defined outside this excerpt; presumably it
# records the current optimizer package so it can be looked up later --
# confirm against its definition.
__PACKAGE__->register;

# updater for kvstore
#
# An Updater pairs an optimizer with the per-index optimizer states and
# applies one update step per call() invocation.
package AI::MXNet::Updater;
use Mouse;
use Storable qw(thaw freeze);
# Code-dereference overload: using an Updater object as a code reference
# (e.g. $updater->($index, $grad, $weight)) forwards the arguments to call().
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } },
             fallback => 1;

# The optimizer whose create_state() and update() methods are invoked by call().
has "optimizer"     => (is => "rw", isa => "AI::MXNet::Optimizer");
# index => optimizer state; filled lazily by call() or replaced wholesale by set_states().
has "states"        => (is => "rw", isa => "HashRef", default => sub { +{} });
# index => flag; set to 1 by call() once the state for that index has been
# created or passed through sync_state_context(), reset to 0 by set_states().
has "states_synced" => (is => "rw", isa => "HashRef", default => sub { +{} });

# Apply one optimizer update for the parameter identified by $index.
#
# Parameters
# ----------
# $index  : Index               key under which the state is stored
# $grad   : AI::MXNet::NDArray  gradient handed to the optimizer
# $weight : AI::MXNet::NDArray  weight handed to the optimizer
#
# The state for $index is created through the optimizer on first use. A state
# that already exists but is not flagged as synced is first passed through
# sync_state_context() with the weight's context. The value of the optimizer's
# update() call is the value of this method.
method call(Index $index, AI::MXNet::NDArray $grad, AI::MXNet::NDArray $weight)
{
    my $state_of  = $self->states;
    my $synced_of = $self->states_synced;
    if(exists $state_of->{ $index })
    {
        # known state that has not been synced yet (flag false or missing)
        unless($synced_of->{ $index })
        {
            $state_of->{ $index }  = $self->sync_state_context($state_of->{ $index }, $weight->context);
            $synced_of->{ $index } = 1;
        }
    }
    else
    {
        # first time this index is seen: let the optimizer build its state
        $state_of->{ $index }  = $self->optimizer->create_state($index, $weight);
        $synced_of->{ $index } = 1;
    }
    $self->optimizer->update($index, $weight, $grad, $state_of->{ $index });
}
# slice is installed as an alias of call (the whole glob is shared)
*slice = *call;

# Return $state relocated to $context.
#
# Parameters
# ----------
# $state   : undef, an AI::MXNet::NDArray, or an array ref of them
# $context : AI::MXNet::Context
#
# A blessed state is converted with as_in_context(); an unblessed reference is
# treated as an array ref whose elements are converted recursively into a new
# array ref; anything else (e.g. undef) is returned unchanged.
method sync_state_context(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $state, AI::MXNet::Context $context)
{
    return $state->as_in_context($context) if blessed $state;
    return $state unless ref $state;

    my @converted;
    for my $element (@{ $state })
    {
        push @converted, $self->sync_state_context($element, $context);
    }
    return \@converted;
}

# Replace the updater's states with the ones serialized in $states.
#
# Parameters
# ----------
# $states : Str
#     A Storable image of a hash ref, as produced by get_states().
#
# Dies (via confess) when $states does not decode to a plain hash ref;
# in that case neither states nor states_synced is modified.
#
# SECURITY: Storable::thaw can instantiate arbitrary classes while decoding.
# Only pass data that comes from a trusted source.
method set_states($states)
{
    my $thawed_states = thaw($states);
    # thaw() yields undef for an image it cannot decode, and a valid image may
    # still hold something other than a hash. The "states" attribute accepts
    # only an unblessed hash ref, so reject everything else here with a
    # message that names the real problem.
    confess("set_states: expected a Storable image of a hash ref of optimizer states")
        unless ref($thawed_states) eq 'HASH';
    $self->states($thawed_states);
    # flag every restored index as not yet synced, so that call() runs
    # sync_state_context() on it before its next update
    %{ $self->states_synced } = map { $_ => 0 } keys %{ $thawed_states };
}

# Serialize the current states hash with Storable::freeze and return the
# resulting image; set_states() accepts this image.
method get_states()
{
    my $current_states = $self->states;
    return freeze($current_states);
}

package AI::MXNet::Optimizer;

# Build an AI::MXNet::Updater that drives the given optimizer.
#
# Parameters
# ----------
# $optimizer : AI::MXNet::Optimizer
#
# Returns
# -------
# a new AI::MXNet::Updater whose optimizer attribute is $optimizer
method get_updater(AI::MXNet::Optimizer $optimizer)
{
    my $updater = AI::MXNet::Updater->new(optimizer => $optimizer);
    return $updater;
}

1;



( run in 0.653 second using v1.01-cache-2.11-cpan-39bf76dae61 )