AI-MXNet


lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN


    It also exposes the following richer information once bound:

    - state information
        - binded: bool, indicating whether the memory buffers needed for computation
        have been allocated.
        - for_training: whether the module is bound for training (if bound).
        - params_initialized: bool, indicating whether the parameters of this module
        have been initialized.
        - optimizer_initialized: bool, indicating whether an optimizer is defined
        and initialized.
        - inputs_need_grad: bool, indicating whether gradients with respect to the
        input data are needed. This may be useful when implementing composition of modules.

    - input/output information
        - data_shapes: an array ref of [name, shape]. In theory, since the memory is allocated,
        we could directly provide the data arrays. But in the case of data parallelization,
        the data arrays might not be of the same shape as viewed from the external world.
        - label_shapes: an array ref of [name, shape]. This might be [] if the module does
        not need labels (e.g. it does not contain a loss function at the top), or if the module
        is not bound for training.
        - output_shapes: an array ref of [name, shape] for outputs of the module.

    - parameters (for modules with parameters)
        - get_params(): return a list ($arg_params, $aux_params). Each of those
        is a hash ref mapping names to NDArrays. Those NDArrays are always on
        CPU. The actual parameters used for computing might be on other devices (GPUs);
        this function retrieves (a copy of) the latest parameters. Therefore, modifying
        the returned NDArrays does not change the parameters the module computes with.
        - set_params($arg_params, $aux_params): assign parameters to the devices
        doing the computation.
        - init_params(...): a more flexible interface to assign or initialize the parameters.

    - setup
        - bind(): prepare environment for computation.
        - init_optimizer(): install optimizer for parameter updating.

    - computation
        - forward(data_batch): forward operation.
        - backward(out_grads=): backward operation.
        - update(): update parameters according to installed optimizer.
        - get_outputs(): get outputs of the previous forward operation.
        - get_input_grads(): get the gradients with respect to the inputs computed
        in the previous backward operation.
        - update_metric(metric, labels): update the performance metric with the results
        of the previous forward computation.

    - other properties (mostly for backward compatibility)
        - symbol: the underlying symbolic graph for this module (if any).
        This property is not necessarily constant. For example, for AI::MXNet::Module::Bucketing,
        this property is simply the *current* symbol being used. For other modules,
        this value might not be well defined.
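
    A minimal pure-Perl sketch of the state flags and of the get_params()/set_params()
    copy semantics listed above. MockModule is hypothetical and not part of AI::MXNet,
    and its methods do no computation. The ordering it enforces (bind, then init_params,
    then init_optimizer) is an assumption modelled on the descriptions above, not a
    copy of the library's own checks.

```perl
use strict;
use warnings;

# Hypothetical stand-in for a module (not part of AI::MXNet): it tracks the
# same state flags as AI::MXNet::Module::Base but does no real computation.
package MockModule {
    use Carp qw(croak);

    sub new {
        my ($class) = @_;
        return bless {
            binded                => 0,
            for_training          => 0,
            params_initialized    => 0,
            optimizer_initialized => 0,
            params                => {},    # stands in for parameters held on devices
        }, $class;
    }

    # "Allocate" buffers; in this sketch for_training defaults to 1.
    sub bind {
        my ($self, %args) = @_;
        $self->{binded}       = 1;
        $self->{for_training} = $args{for_training} // 1;
        return $self;
    }

    sub init_params {
        my ($self) = @_;
        croak 'bind() must be called first' unless $self->{binded};
        $self->{params}             = { weight => 0.5 };
        $self->{params_initialized} = 1;
        return $self;
    }

    sub init_optimizer {
        my ($self) = @_;
        croak 'bind() and init_params() must be called first'
            unless $self->{binded} && $self->{params_initialized};
        $self->{optimizer_initialized} = 1;
        return $self;
    }

    # Returns a shallow copy, so edits to the result never reach the module.
    sub get_params {
        my ($self) = @_;
        my %copy = %{ $self->{params} };
        return \%copy;
    }

    # The only way to push new values back into the module.
    sub set_params {
        my ($self, $params) = @_;
        croak 'bind() must be called first' unless $self->{binded};
        $self->{params} = { %$params };
        return $self;
    }
}

my $mod = MockModule->new;
my $early_rejected = eval { $mod->init_optimizer; 1 } ? 0 : 1;   # not bound yet

$mod->bind(for_training => 1)->init_params->init_optimizer;

my $copy = $mod->get_params;
$copy->{weight} = 99;                  # changes only the copy
my $before = $mod->{params}{weight};   # still 0.5
$mod->set_params($copy);               # now the module holds 99
my $after = $mod->{params}{weight};
print "early init_optimizer rejected: $early_rejected; weight $before -> $after\n";
```

    Because get_params() hands back a copy, the change to $copy only reaches the module
    after set_params(). fit() performs the same get_params()/set_params() round trip when
    it syncs parameters at the end of each epoch.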

    When these intermediate-level APIs are implemented properly, the following
    high-level APIs become automatically available for a module:

        - fit: train the module parameters on a data set
        - predict: run prediction on a data set and collect outputs
        - score: run prediction on a data set and evaluate performance
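
    A sketch of that idea, assuming a hypothetical DoublerModule that implements only
    forward() and get_outputs(), and a hypothetical predict_all() helper written purely
    against those two calls. This is not the library's predict(); plain array refs stand
    in for data batches and outputs.

```perl
use strict;
use warnings;

# Hypothetical module implementing only forward() and get_outputs();
# its "prediction" is simply each input value doubled.
package DoublerModule {
    sub new { return bless { outputs => [] }, shift }

    sub forward {
        my ($self, $batch, %args) = @_;
        $self->{outputs} = [ map { 2 * $_ } @$batch ];
        return;
    }

    sub get_outputs {
        my ($self) = @_;
        return [ @{ $self->{outputs} } ];
    }
}

# predict()-style helper that relies on nothing but the two calls above,
# so any object providing them can use it.
sub predict_all {
    my ($module, $batches) = @_;
    my @collected;
    for my $batch (@$batches) {
        $module->forward($batch, is_train => 0);    # inference, not training
        push @collected, @{ $module->get_outputs };
    }
    return \@collected;
}

my $out = predict_all(DoublerModule->new, [ [1, 2], [3] ]);
print "@$out\n";    # prints "2 4 6"
```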
=cut

has 'logger'            => (is => 'rw', default => sub { AI::MXNet::Logging->get_logger });
has '_symbol'           => (is => 'rw', init_arg => 'symbol', isa => 'AI::MXNet::Symbol');
has [
    qw/binded for_training inputs_need_grad
    params_initialized optimizer_initialized/
]                       => (is => 'rw', isa => 'Bool', init_arg => undef, default => 0);

################################################################################
# High Level API
################################################################################

=head2 forward_backward

    A convenience method that calls forward (with is_train => 1) and then backward.
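
    A sketch of the call sequence, using a hypothetical RecordingModule that only logs
    which methods run. As in forward_backward, forward() runs in training mode and
    backward() follows with no explicit output gradients. update() is not part of it,
    so in fit() the parameter step remains a separate $self->update call.

```perl
use strict;
use warnings;

# Hypothetical module that records which methods were called, and how.
package RecordingModule {
    sub new { return bless { calls => [] }, shift }

    sub forward {
        my ($self, $batch, %args) = @_;
        push @{ $self->{calls} }, "forward(is_train=$args{is_train})";
        return;
    }

    sub backward {
        my ($self) = @_;
        push @{ $self->{calls} }, 'backward';
        return;
    }

    # Same shape as forward_backward: a training-mode forward, then backward.
    sub forward_backward {
        my ($self, $batch) = @_;
        $self->forward($batch, is_train => 1);
        $self->backward();
        return;
    }
}

my $m = RecordingModule->new;
$m->forward_backward([1, 2, 3]);
print join(', ', @{ $m->{calls} }), "\n";    # prints "forward(is_train=1), backward"
```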
=cut

method forward_backward(AI::MXNet::DataBatch $data_batch)
{
    $self->forward($data_batch, is_train => 1);
    $self->backward();
}

=head2 score

    Run prediction on eval_data and evaluate the performance according to
    eval_metric.

    Parameters
    ----------
    $eval_data   : AI::MXNet::DataIter
    $eval_metric : AI::MXNet::EvalMetric
        If this is not an AI::MXNet::EvalMetric object (for example, a metric name),
        it is passed to AI::MXNet::Metric->create.
    :$num_batch= : Maybe[Int]
        Number of batches to run. Default is undef, meaning run until the AI::MXNet::DataIter
        finishes.
    :$batch_end_callback= : Maybe[Callback]
        Could also be an array ref of functions.
    :$score_end_callback= : Maybe[Callback]
        Could also be an array ref of functions.
    :$reset=1 : Bool
        Default 1, indicating whether $eval_data should be reset before evaluation
        starts.
    :$epoch=0 : Int
        Default is 0. For compatibility, this will be passed to callbacks (if any). During
        training, this will correspond to the training epoch number.
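
    A rough sketch of the evaluation loop these parameters control. It assumes
    hypothetical helpers (ToyAccuracy, as_list, toy_score) and plain Perl data in place
    of AI::MXNet::DataIter and NDArrays. It shows two documented behaviours: stopping
    after :$num_batch batches when that is defined, and accepting either one callback
    or an array ref of callbacks. The training loop further down passes
    AI::MXNet::BatchEndParam objects to its callbacks. This sketch passes a plain hash
    ref instead and does not model :$reset.

```perl
use strict;
use warnings;

# Hypothetical accuracy metric with reset/update/get methods.
package ToyAccuracy {
    sub new { return bless { hit => 0, n => 0 }, shift }

    sub reset {
        my ($self) = @_;
        @$self{qw(hit n)} = (0, 0);
        return;
    }

    sub update {
        my ($self, $labels, $preds) = @_;
        for my $i (0 .. $#$labels) {
            $self->{hit}++ if $labels->[$i] == $preds->[$i];
            $self->{n}++;
        }
        return;
    }

    sub get {
        my ($self) = @_;
        return ('accuracy', $self->{n} ? $self->{hit} / $self->{n} : 0);
    }
}

# Accept a single code ref or an array ref of code refs.
sub as_list {
    my ($x) = @_;
    return ref $x eq 'ARRAY' ? $x : [$x];
}

# score()-like loop: stop after num_batch batches when it is defined,
# otherwise consume every batch; run each batch-end callback per batch.
sub toy_score {
    my ($predict, $batches, $metric, %opt) = @_;
    $metric->reset;
    my $nbatch = 0;
    for my $batch (@$batches) {
        last if defined $opt{num_batch} && $nbatch == $opt{num_batch};
        $metric->update($batch->{label}, $predict->($batch->{data}));
        if (defined $opt{batch_end_callback}) {
            my $param = { epoch => $opt{epoch} // 0, nbatch => $nbatch, eval_metric => $metric };
            $_->($param) for @{ as_list($opt{batch_end_callback}) };
        }
        $nbatch++;
    }
    my %result = $metric->get;
    return \%result;
}

my @batches = (
    { data => [ 0.9, -0.2], label => [1, 0] },
    { data => [-0.5,  0.3], label => [1, 1] },
    { data => [ 0.7,  0.1], label => [1, 1] },
);
my $predict = sub { [ map { $_ > 0 ? 1 : 0 } @{ $_[0] } ] };    # threshold at 0

my @seen;
my $res = toy_score($predict, \@batches, ToyAccuracy->new,
    num_batch          => 2,
    batch_end_callback => sub { push @seen, $_[0]{nbatch} });
printf "accuracy=%.2f after %d batches\n", $res->{accuracy}, scalar @seen;
```

    With num_batch => 2 only the first two batches are scored, so the reported accuracy
    is 3 of 4 labels. Leaving num_batch undefined runs through all three batches.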
=cut

method score(
    AI::MXNet::DataIter                 $eval_data,
    EvalMetric                          $eval_metric,
    Maybe[Int]                          :$num_batch=,
    Maybe[Callback]|ArrayRef[Callback]  :$batch_end_callback=,
    Maybe[Callback]|ArrayRef[Callback]  :$score_end_callback=,
    Bool                                :$reset=1,
    Int                                 :$epoch=0
)
{
    assert($self->binded and $self->params_initialized);
    $eval_data->reset if $reset;
    if(not blessed $eval_metric or not $eval_metric->isa('AI::MXNet::EvalMetric'))
    {
        $eval_metric = AI::MXNet::Metric->create($eval_metric);
    }

# [... lines omitted between excerpts ...]

        force_init    => $force_init
    );
    $self->init_optimizer(
        kvstore          => $kvstore,
        optimizer        => $optimizer,
        optimizer_params => $optimizer_params
    );

    if(not defined $validation_metric)
    {
        $validation_metric = $eval_metric;
    }
    $eval_metric = AI::MXNet::Metric->create($eval_metric)
        unless blessed $eval_metric;

    ################################################################################
    # training loop
    ################################################################################
    for my $epoch ($begin_epoch..$num_epoch-1)
    {
        my $tic = time;
        $eval_metric->reset;
        my $nbatch = 0;
        my $end_of_batch = 0;
        my $next_data_batch = <$train_data>;
        while(not $end_of_batch)
        {
            my $data_batch = $next_data_batch;
            $monitor->tic if $monitor;
            $self->forward_backward($data_batch);
            $self->update;
            $next_data_batch = <$train_data>;
            if(defined $next_data_batch)
            {
                $self->prepare($next_data_batch);
            }
            else
            {
                $end_of_batch = 1;
            }
            $self->update_metric($eval_metric, $data_batch->label);
            $monitor->toc_print if $monitor;
            if(defined $batch_end_callback)
            {
                my $batch_end_params = AI::MXNet::BatchEndParam->new(
                    epoch       => $epoch,
                    nbatch      => $nbatch,
                    eval_metric => $eval_metric
                );
                for my $callback (@{ _as_list($batch_end_callback) })
                {
                    &{$callback}($batch_end_params);
                }
            }
            $nbatch++;
        }
        # one epoch of training is finished
        my $name_value = $eval_metric->get_name_value;
        while(my ($name, $val) = each %{ $name_value })
        {
            $self->logger->info('Epoch[%d] Train-%s=%f', $epoch, $name, $val);
        }
        my $toc = time;
        $self->logger->info('Epoch[%d] Time cost=%.3f', $epoch, ($toc-$tic));

        # sync aux params across devices
        my ($arg_params, $aux_params) = $self->get_params;
        $self->set_params($arg_params, $aux_params);

        if($epoch_end_callback)
        {
            for my $callback (@{ _as_list($epoch_end_callback) })
            {
                &{$callback}($epoch, $self->get_symbol, $arg_params, $aux_params);
            }
        }
        #----------------------------------------
        # evaluation on validation set
        if(defined $eval_data)
        {
            my $res = $self->score(
                $eval_data,
                $validation_metric,
                score_end_callback => $eval_end_callback,
                batch_end_callback => $eval_batch_end_callback,
                epoch              => $epoch
            );
            #TODO: pull this into default
            while(my ($name, $val) = each %{ $res })
            {
                $self->logger->info('Epoch[%d] Validation-%s=%f', $epoch, $name, $val);
            }
        }
        # end of 1 epoch, reset the data-iter for another epoch
        $train_data->reset;
    }
}
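
=pod

    The training loop above fetches the next batch before it records the metric for the
    current one. This lets prepare() be called on the upcoming batch, and an undefined
    fetch marks the end of the epoch. A stripped-down sketch of just that control flow,
    with a hypothetical closure-based iterator (make_iter) and log messages standing in
    for the real calls:

```perl
use strict;
use warnings;

# Hypothetical iterator: hands out batches one at a time, then undef.
sub make_iter {
    my @queue = @_;
    return sub { return shift @queue };
}

# Mirrors the loop structure: fetch ahead after the training step so the
# next batch can be prepared, then record the metric for the current batch.
sub run_epoch {
    my ($next, $log) = @_;
    my $nbatch          = 0;
    my $end_of_batch    = 0;
    my $next_data_batch = $next->();
    while (not $end_of_batch) {
        my $data_batch = $next_data_batch;
        push @$log, "train $data_batch";       # forward_backward + update
        $next_data_batch = $next->();
        if (defined $next_data_batch) {
            push @$log, "prepare $next_data_batch";
        }
        else {
            $end_of_batch = 1;                 # iterator exhausted
        }
        push @$log, "metric $data_batch";      # update_metric on the current batch
        $nbatch++;
    }
    return $nbatch;
}

my @log;
my $n = run_epoch(make_iter(qw(b0 b1 b2)), \@log);
print "$n batches\n", join("\n", @log), "\n";
```

    Like the loop it mirrors, the sketch assumes the first fetch returns a batch. With an
    empty iterator, the first training step would receive undef.

=cut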

################################################################################
# Symbol information
################################################################################

=head2 get_symbol

    The symbol used by this module.
=cut
method get_symbol() { $self->symbol }

=head2 data_names

    An array ref of names for data required by this module.
=cut
method data_names() { confess("NotImplemented") }

=head2 output_names

    An array ref of names for the outputs of this module.
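
    data_names() and output_names(), like the shape accessors below, are abstract. The
    base class only confesses, and a concrete module is expected to override them. A
    minimal sketch of that pattern with hypothetical classes ToyBase and ToyConcrete
    (the output name softmax_output is made up for illustration):

```perl
use strict;
use warnings;

# Hypothetical base class: accessors that every concrete module must override.
package ToyBase {
    use Carp qw(confess);
    sub new          { return bless {}, shift }
    sub data_names   { confess('NotImplemented') }
    sub output_names { confess('NotImplemented') }
}

# Hypothetical concrete module that supplies the names.
package ToyConcrete {
    use parent -norequire, 'ToyBase';
    sub data_names   { return ['data'] }
    sub output_names { return ['softmax_output'] }
}

my $ok  = eval { ToyBase->new->data_names; 1 };
my $err = $@;
print $ok ? "base returned a value\n" : "base class died with NotImplemented\n";
print join(',', @{ ToyConcrete->new->output_names }), "\n";    # prints "softmax_output"
```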
=cut
method output_names() { confess("NotImplemented") }

################################################################################
# Input/Output information
################################################################################

=head2 data_shapes

    An array ref of AI::MXNet::DataDesc objects specifying the data inputs to this module.
=cut
method data_shapes() { confess("NotImplemented") }

=head2 label_shapes

    An array ref of AI::MXNet::DataDesc objects specifying the label inputs to this module.
    If this module does not accept labels, either because it has no loss function
    or because it is not bound for training, this should return an empty
    array ref.
=cut
method label_shapes() { confess("NotImplemented") }

=head2 output_shapes

    An array ref of (name, shape) array refs specifying the outputs of this module.
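
    A small sketch of the [name, shape] layout, using a made-up output named
    softmax_output with shape (32, 10). The array ref keeps outputs in order for
    iteration. Building a hash from it gives lookup by name.

```perl
use strict;
use warnings;

# Made-up output_shapes value: one output with shape (32, 10),
# stored as an array ref of [name, shape] pairs.
my $output_shapes = [ [ 'softmax_output', [32, 10] ] ];

# Lookup by name, built from the ordered pairs.
my %shape_of = map { my ($name, $shape) = @$_; ($name => $shape) } @$output_shapes;

# Element count of each output = product of its dimensions.
my %size_of;
for my $pair (@$output_shapes) {
    my ($name, $shape) = @$pair;
    my $size = 1;
    $size *= $_ for @$shape;
    $size_of{$name} = $size;
    printf "%s: (%s) -> %d elements\n", $name, join(', ', @$shape), $size;
}
```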
=cut
method output_shapes() { confess("NotImplemented") }

################################################################################
# Parameters of a module
################################################################################

=head2 get_params


