AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/LRScheduler.pm  view on Meta::CPAN

package AI::MXNet::LRScheduler;
use strict;
use warnings;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Logging;
# Lets a scheduler object be used as a code reference:
# $scheduler->($num_update) is forwarded to $scheduler->call($num_update).
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } },
             fallback => 1;

=head1 NAME

    AI::MXNet::LRScheduler - The adaptive scheduler of the learning rate.

=cut

=head1 DESCRIPTION

    Learning rate scheduler, which adaptively changes the learning rate based on
    the progress of training. This is an abstract base class: concrete
    schedulers (e.g. AI::MXNet::FactorScheduler) must override the call method.

=cut

=head2 new

    base_lr : float (optional, default 0.01)
        the initial learning rate

=cut

has 'base_lr' => (is => 'rw', isa => 'Num', default => 0.01);

=head2 call

    Returns the learning rate to use for the current update.

    The training progress is presented by num_update, which can be roughly
    viewed as the number of minibatches executed so far. Its value is
    non-decreasing, and increases at most by one.

    The exact value is the upper bound of the number of updates applied to
    a weight/index.

    See more details in https://github.com/dmlc/mxnet/issues/625

    Parameters
    ----------
    num_update: int
        the maximal number of updates applied to a weight.

=cut

method call(Int $num_update)
{
    # Abstract method: the "&{}" overload above dispatches here, so give a
    # clear error instead of "Can't locate object method" when a subclass
    # does not provide its own implementation.
    confess("NotImplemented");
}

package AI::MXNet::FactorScheduler;

=head1 NAME

    AI::MXNet::FactorScheduler - Reduces the learning rate by a factor.

=head1 DESCRIPTION

    Reduces the learning rate by a constant factor every `step` updates.
    Assume the weight has been updated by n times, then the learning rate will
    be base_lr * factor^(floor(n/step)), but never lower than stop_factor_lr.

    Parameters
    ----------
    step: int
        the number of updates between two consecutive reductions (must be >= 1)
    factor: float (optional, default 1)
        the factor by which to reduce the learning rate (must be <= 1)
    stop_factor_lr: float (optional, default 1e-8)
        lower bound for the learning rate; once reached, it stays there

=cut

use Mouse;
extends 'AI::MXNet::LRScheduler';

has 'step'            => (is => 'ro', isa => 'Int', required => 1);
has 'factor'          => (is => 'ro', isa => 'Num', default  => 1);
# Number of updates already covered by the reductions applied so far (always a
# multiple of step). It starts at 0, like MultiFactorScheduler's count, so the
# first reduction happens as soon as num_update exceeds step; a default of 1
# delayed every reduction by one update.
has 'count'           => (is => 'rw', isa => 'Int', default  => 0);
has 'stop_factor_lr'  => (is => 'ro', isa => 'Num', default  => 1e-8);

# Validates the constructor arguments (invoked by Mouse after construction).
sub BUILD
{
    my $self = shift;
    confess("Schedule step must be greater or equal than 1")
        if $self->step < 1;
    confess("Factor must be no more than 1 to make lr reduce")
        if $self->factor > 1;
}

# Returns the learning rate for update number $num_update, first applying
# every reduction whose threshold (count + step) has been passed.
method call(Int $num_update)
{
    # NOTE: use while rather than if (for continuing training via load_epoch):
    # num_update may have passed several thresholds since the last call.
    while($num_update > $self->count + $self->step)
    {
        $self->count($self->count + $self->step);
        $self->base_lr($self->base_lr * $self->factor);
        if($self->base_lr < $self->stop_factor_lr)
        {
            # Clamp at the floor; further reductions keep it at this value.
            $self->base_lr($self->stop_factor_lr);
            AI::MXNet::Logging->info(
                "Update[%d]: now learning rate arrived at %0.5e, will not "
                ."change in the future", $num_update, $self->base_lr
            );
        }
        else
        {
            AI::MXNet::Logging->info(
                "Update[%d]: Changed learning rate to %0.5e",
                $num_update, $self->base_lr
            );
        }
    }
    return $self->base_lr;
}

package AI::MXNet::MultiFactorScheduler;

=head1 NAME

    AI::MXNet::MultiFactorScheduler - Reduces the learning rate by a factor at a list of milestones.

=head1 DESCRIPTION

    Multiplies the learning rate by factor each time the update counter
    passes one of the milestones listed in the step array ref.
    Assume the weight has been updated by n times, then the learning rate will
    be base_lr * factor^(sum((step/n)<=1)) # step is an array.

    Parameters
    ----------
    step: array ref of int
        strictly increasing milestones, each >= 1; the learning rate is
        reduced once num_update exceeds each of them
    factor: float (optional, default 1)
        the factor for reducing the learning rate (must be <= 1)

=cut

use Mouse;
extends 'AI::MXNet::LRScheduler';
has 'step'         => (is => 'ro', isa => 'ArrayRef[Int]', required => 1);
has 'factor'       => (is => 'ro', isa => 'Num',           default  => 1);
# Index into step of the next milestone that has not been passed yet.
has 'cur_step_ind' => (is => 'rw', isa => 'Int',           default  => 0);
# The most recently passed milestone (0 until the first reduction).
has 'count'        => (is => 'rw', isa => 'Int',           default  => 0);

# Rejects an empty, non-increasing or non-positive milestone list and a
# factor that would increase the learning rate.
sub BUILD
{
    my ($self) = @_;
    my @milestones = @{ $self->step };
    confess("step array must have at least one member")
        unless @milestones >= 1;
    for my $pos (0 .. $#milestones)
    {
        confess("Schedule step must be an increasing integer list")
            if $pos > 0 and $milestones[$pos] <= $milestones[$pos - 1];
        confess("Schedule step must be greater or equal than 1")
            if $milestones[$pos] < 1;
    }
    confess("Factor must be no more than 1 to make lr reduce")
        if $self->factor > 1;
}

# Returns the learning rate for update number $num_update, consuming every
# milestone that $num_update has already exceeded.
method call(Int $num_update)
{
    my $milestones = $self->step;
    # Loop rather than a single check (for continuing training via
    # load_epoch): several milestones may be passed in one call.
    while ($self->cur_step_ind < @$milestones
           and $num_update > $milestones->[ $self->cur_step_ind ])
    {
        my $passed = $milestones->[ $self->cur_step_ind ];
        $self->count($passed);
        $self->cur_step_ind($self->cur_step_ind + 1);
        $self->base_lr($self->base_lr * $self->factor);
        AI::MXNet::Logging->info(
            "Update[%d]: Changed learning rate to %0.5e",
            $num_update, $self->base_lr
        );
    }
    return $self->base_lr;
}

1;



( run in 1.823 second using v1.01-cache-2.11-cpan-39bf76dae61 )