AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Callback.pm  view on Meta::CPAN

        if(($iter_no + 1) % $period == 0)
        {
            $mod->save_checkpoint($prefix, $iter_no + 1, $save_optimizer_states);
        }
    }
}

=head2 log_train_metric

    Callback that logs the training evaluation metric every $period batches.

    Parameters
    ----------
    $period : Int
        The number of batches after which to log the training evaluation metric.
    $auto_reset : Bool
        Whether to reset the metric after the logging.

    Returns
    -------
    $callback : sub ref
        The callback function that can be passed as batch_end_callback to fit.

=cut

method log_train_metric(Int $period, Int $auto_reset=0)
{
    # Build a closure that is called once per batch with a parameter object
    # exposing ->epoch, ->nbatch and ->eval_metric.
    return sub {
        my ($param) = @_;
        # Log only on every $period-th batch, and only when a metric is attached.
        if($param->nbatch % $period == 0 and defined $param->eval_metric)
        {
            my $name_value = $param->eval_metric->get_name_value;
            # Walk the metric names in sorted order: Perl randomizes hash
            # iteration order, so plain `each` would emit multi-metric lines
            # in a different order from call to call and run to run.
            for my $name (sort keys %{ $name_value })
            {
                AI::MXNet::Logging->info(
                    "Iter[%d] Batch[%d] Train-%s=%f",
                    $param->epoch, $param->nbatch, $name, $name_value->{$name}
                );
            }
            # Reset only after every value has been logged.
            $param->eval_metric->reset if $auto_reset;
        }
    }
}

package AI::MXNet::Speedometer;
use Mouse;
use Time::HiRes qw/time/;
extends 'AI::MXNet::Callback';

=head1 NAME

    AI::MXNet::Speedometer - A callback that logs training speed

=cut

=head1 DESCRIPTION

    Calculate and log training speed (samples/sec) periodically, together
    with the current evaluation metric values when a metric is attached.

    Parameters
    ----------
    batch_size: int
        batch_size of data
    frequent: int
        How many batches between calculations.
        Defaults to calculating & logging every 50 batches.
    auto_reset: Bool
        Reset the metric after each log, defaults to true.

=cut

has 'batch_size' => (is => 'ro', isa => 'Int', required => 1);
has 'frequent'   => (is => 'ro', isa => 'Int', default  => 50);
# Internal timing state, maintained by call():
#   init       - true once tic holds a start time for the current run of batches
#   tic        - time() at which the current measurement window started
#   last_count - nbatch seen on the previous call; a smaller value means the
#                batch counter restarted (presumably a new epoch)
has 'init'       => (is => 'rw', isa => 'Int', default  => 0);
has 'tic'        => (is => 'rw', isa => 'Num', default  => 0);
has 'last_count' => (is => 'rw', isa => 'Int', default  => 0);
has 'auto_reset' => (is => 'ro', isa => 'Bool', default  => 1);

# Per-batch hook: on every `frequent`-th batch, log throughput since the last
# report (plus metric values if present), then restart the timer.
method call(AI::MXNet::BatchEndParam $param)
{
    my $count = $param->nbatch;
    # The batch counter went backwards, so the timer must be re-armed rather
    # than measuring across the restart.
    if($self->last_count > $count)
    {
        $self->init(0);
    }
    $self->last_count($count);

    if($self->init)
    {
        if(($count % $self->frequent) == 0)
        {
            # Assumes `frequent` batches of `batch_size` samples ran since tic.
            # NOTE(review): time() is wall-clock; a clock step could make the
            # elapsed interval non-positive.
            my $speed = $self->frequent * $self->batch_size / (time - $self->tic);
            if(defined $param->eval_metric)
            {
                my $name_value = $param->eval_metric->get_name_value;
                # Values were captured above, so resetting before logging is safe.
                $param->eval_metric->reset if $self->auto_reset;
                # Sorted so metric lines appear in a stable order; hash
                # iteration order is randomized by Perl.
                for my $name (sort keys %{ $name_value })
                {
                    AI::MXNet::Logging->info(
                        "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f",
                        $param->epoch, $count, $speed, $name, $name_value->{$name}
                    );
                }
            }
            else
            {
                AI::MXNet::Logging->info(
                    "Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                    $param->epoch, $count, $speed
                );
            }
            # Start the next measurement window.
            $self->tic(time);
        }
    }
    else
    {
        # First call (or first after a restart): just arm the timer.
        $self->init(1);
        $self->tic(time);
    }
}

# Make call() reachable under the name `slice` as well.
# NOTE(review): presumably required by the AI::MXNet::Callback interface — confirm.
*slice = \&call;

package AI::MXNet::ProgressBar;



( run in 1.421 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )