AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/LRScheduler.pm  view on Meta::CPAN

package AI::MXNet::LRScheduler;
use strict;
use warnings;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Logging;

# Make scheduler objects callable like code references:
# $scheduler->(@args) is forwarded to $scheduler->call(@args).
use overload
    '&{}'    => sub {
        my ($self) = @_;
        return sub { $self->call(@_) };
    },
    fallback => 1;

=head1 NAME

    AI::MXNet::LRScheduler - The adaptive scheduler of the learning rate.

=cut

=head1 DESCRIPTION

    Base class for learning rate schedulers, which adaptively change the
    learning rate as training progresses.

=cut

=head2 new

    base_lr : float (optional, default 0.01)
        the initial learning rate

=cut

# Read-write because subclasses (e.g. AI::MXNet::FactorScheduler) overwrite
# it as they adjust the learning rate.
has base_lr => (
    is      => 'rw',
    isa     => 'Num',
    default => 0.01,
);

=head2 call

    Schedules the learning rate for the current point of training.
    Thanks to the '&{}' overload, $scheduler->($num_update) is equivalent
    to $scheduler->call($num_update).

    Training progress is given by num_update, which can be roughly viewed
    as the number of minibatches processed so far. It never decreases, and
    it grows by at most one at a time.

    More precisely, it is an upper bound on the number of updates that have
    been applied to a weight/index.

    See https://github.com/dmlc/mxnet/issues/625 for more details.

    Parameters
    ----------
    num_update: int
        the maximal number of updates applied to a weight.

=cut

package AI::MXNet::FactorScheduler;

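=head1 NAME

    AI::MXNet::FactorScheduler - Reduces the learning rate by a factor.

=head1 DESCRIPTION

    Reduces the learning rate by a constant factor every `step` updates.
    If the weight has been updated n times, the learning rate will be
    roughly base_lr * factor^(floor(n/step)). It is never reduced below
    stop_factor_lr.

    Parameters
    ----------
    step: int
        the number of updates between two consecutive learning rate
        reductions (must be at least 1)
    factor: float (optional, default 1)
        the factor by which to reduce the learning rate (must be no more
        than 1)
    stop_factor_lr: float (optional, default 1e-8)
        the lower bound for the learning rate

=cut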
use Mouse;
extends 'AI::MXNet::LRScheduler';

# Number of updates between two consecutive learning rate reductions;
# BUILD rejects values below 1.
has step => (
    is       => 'ro',
    isa      => 'Int',
    required => 1,
);

# Multiplier applied to base_lr at each reduction; BUILD rejects values
# above 1 so the rate can only shrink.
has factor => (
    is      => 'ro',
    isa     => 'Num',
    default => 1,
);

# Bookkeeping for call(): a reduction happens once num_update exceeds
# count + step, and count is then advanced by step.
has count => (
    is      => 'rw',
    isa     => 'Int',
    default => 1,
);

# Floor for the learning rate: once a reduction would drop base_lr below
# this value, base_lr is set to it instead.
has stop_factor_lr => (
    is      => 'ro',
    isa     => 'Num',
    default => 1e-8,
);

# Validates constructor arguments after Mouse has populated the object.
# Dies (via confess) when step is below 1 or when factor exceeds 1.
sub BUILD
{
    my ($self) = @_;
    if ($self->step < 1)
    {
        confess("Schedule step must be greater or equal than 1");
    }
    # A factor above 1 would make the learning rate grow instead of decay.
    if ($self->factor > 1)
    {
        confess("Factor must be no more than 1 to make lr reduce");
    }
}

method call(Int $num_update)
{
    # NOTE: use while rather than if  (for continuing training via load_epoch)
    while($num_update > $self->count + $self->step)
    {
        $self->count($self->count + $self->step);
        $self->base_lr($self->base_lr * $self->factor);
        if($self->base_lr < $self->stop_factor_lr)
        {
            $self->base_lr($self->stop_factor_lr);



( run in 0.839 second using v1.01-cache-2.11-cpan-39bf76dae61 )