AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

package AI::MXNet::BatchEndParam;
use Mouse;
use AI::MXNet::Function::Parameters;

# Plain read-write attribute holder: integer epoch and batch counters plus an
# evaluation metric object. Judging by the class name it is presumably handed
# to batch-end callbacks -- confirm against callers.
has 'epoch'       => (is => 'rw', isa => 'Int');
has 'nbatch'      => (is => 'rw', isa => 'Int');
has 'eval_metric' => (is => 'rw', isa => 'AI::MXNet::EvalMetric');

package AI::MXNet::Module::Base;
use Mouse;
use AI::MXNet::Base;
use Time::HiRes qw(time);

=head1 NAME

    AI::MXNet::Module::Base - Base class for AI::MXNet::Module and AI::MXNet::Module::Bucketing
=cut

# Return $obj unchanged when ref() reports exactly 'ARRAY' (an unblessed
# array reference); otherwise wrap it in a new one-element array ref.
# Blessed references, including blessed arrays, are therefore wrapped.
func _as_list($obj)
{
    my $kind = ref($obj) // '';
    return $kind eq 'ARRAY' ? $obj : [$obj];
}

# Check that every name in $names is an argument of $symbol.
#
# Each name absent from $symbol->list_arguments (the name 'softmax_label' is
# always accepted) produces a message that lists the symbol's arguments not
# ending in _weight/_bias/_gamma/_beta as suggestions. The message is raised
# with confess when $throw is true, otherwise it is passed to
# AI::MXNet::Logging->warning.
method _check_input_names(
    AI::MXNet::Symbol $symbol,
    ArrayRef[Str]     $names,
    Str               $typename,
    Bool              $throw
)
{
    my @arguments   = @{ $symbol->list_arguments };
    my %is_argument = map { $_ => 1 } @arguments;
    # Arguments with parameter-like suffixes are left out of the suggestion
    # list; only the remaining ones are offered as "did you mean" candidates.
    my @candidates  = grep { not /_(?:weight|bias|gamma|beta)$/ } @arguments;
    for my $name (@$names)
    {
        next if exists $is_argument{$name} or $name eq 'softmax_label';
        my $msg = sprintf("\033[91mYou created Module with Module(..., %s_names=%s) but "
            ."input with name '%s' is not found in symbol.list_arguments(). "
            ."Did you mean one of:\n\t%s\033[0m",
            $typename, "@$names", $name, join("\n\t", @candidates)
        );
        if($throw)
        {
            confess($msg);
        }
        else
        {
            AI::MXNet::Logging->warning($msg);
        }
    }
}

# Check that the names in $data_names match, in the same order, the names
# carried by the entries of $data_shapes (element 0 of each entry).
# On mismatch the message is raised with confess when $throw is true,
# otherwise it is passed to AI::MXNet::Logging->warning.
method _check_names_match(
    ArrayRef[Str]                  $data_names,
    ArrayRef[NameShapeOrDataDesc]  $data_shapes,
    Str                            $name,
    Bool                           $throw
)
{
    # No shapes at all against the single name 'softmax_label' is accepted.
    return if (not @$data_shapes and @$data_names == 1 and  $data_names->[0] eq 'softmax_label');
    my @actual = map { $_->[0] } @{ $data_shapes };
    # Compare element by element: comparing space-joined strings would treat
    # ('a b') and ('a', 'b') as the same list of names.
    my $same = @$data_names == @actual
        && !grep { $data_names->[$_] ne $actual[$_] } 0 .. $#actual;
    return if $same;
    my $msg = sprintf(
        "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)",
        $name, $name, "@$data_shapes", "@$data_names"
    );
    if($throw)
    {
        confess($msg);
    }
    else
    {
        AI::MXNet::Logging->warning($msg);
    }
}

# Normalize data and label shape descriptions to AI::MXNet::DataDesc objects
# and check them against the declared names.
#
# Unblessed entries are expanded as the argument list of DataDesc->new.
# A data-name mismatch is fatal; a label-name mismatch only warns.
# $label_names may be undef (Maybe[...]); it is then checked as an empty
# list, because _check_names_match requires an array ref.
# Returns ($data_shapes, $label_shapes); $label_shapes is returned as given
# when it is false/undef.
method _parse_data_desc(
    ArrayRef[Str]                                  $data_names,
    Maybe[ArrayRef[Str]]                           $label_names,
    ArrayRef[NameShapeOrDataDesc]                  $data_shapes,
    Maybe[ArrayRef[NameShapeOrDataDesc]]           $label_shapes
)
{
    my $checked_label_names = $label_names // [];
    $data_shapes = [map { blessed $_ ? $_ : AI::MXNet::DataDesc->new(@$_) } @$data_shapes];
    $self->_check_names_match($data_names, $data_shapes, 'data', 1);
    if($label_shapes)
    {
        $label_shapes = [map { blessed $_ ? $_ : AI::MXNet::DataDesc->new(@$_) } @$label_shapes];
        $self->_check_names_match($checked_label_names, $label_shapes, 'label', 0);
    }
    else
    {
        $self->_check_names_match($checked_label_names, [], 'label', 0);
    }
    return ($data_shapes, $label_shapes);
}

=head1 DESCRIPTION

    The base class of modules. A module represents a computation component. The design
    purpose of a module is that it abstracts a computation "machine" that one can run forward,
    backward, update parameters, etc. We aim to make the APIs easy to use, especially in the
    case when we need to use the imperative API to work with multiple modules (e.g. stochastic
    depth network).

    A module has several states:

        - Initial state. Memory is not allocated yet, not ready for computation yet.
        - Binded. Shapes for inputs, outputs, and parameters are all known, memory allocated,
        ready for computation.
        - Parameter initialized. For modules with parameters, doing computation before initializing
        the parameters might result in undefined outputs.
        - Optimizer installed. An optimizer can be installed to a module. After this, the parameters
        of the module can be updated according to the optimizer after gradients are computed
        (forward-backward).

    In order for a module to interact with others, a module should be able to report the
    following information in its raw stage (before binding):

        - data_names: array ref of strings indicating the names of required data.
        - output_names: array ref of strings indicating the names of required outputs.

    And also the following richer information after binded:

    - state information
        - binded: bool, indicating whether the memory buffers needed for computation

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

################################################################################
# Input/Output information
################################################################################

=head2 data_shapes

    An array ref of AI::MXNet::DataDesc objects specifying the data inputs to this module.
=cut
method data_shapes()
{
    # Abstract accessor: this base implementation always dies.
    confess("NotImplemented");
}

=head2 label_shapes

    An array ref of AI::MXNet::DataDesc objects specifying the label inputs to this module.
    If this module does not accept labels -- either it is a module without a loss
    function, or it is not binded for training, then this should return an empty
    array ref.
=cut
method label_shapes()
{
    # Abstract accessor: this base implementation always dies.
    confess("NotImplemented");
}

=head2 output_shapes

    An array ref of (name, shape) array refs specifying the outputs of this module.
=cut
method output_shapes()
{
    # Abstract accessor: this base implementation always dies.
    confess("NotImplemented");
}

################################################################################
# Parameters of a module
################################################################################

=head2 get_params

    The parameters. These are potentially copies of the actual parameters used
    to do computation on the device.

    Returns
    -------
    ($arg_params, $aux_params), a pair of hash refs of name to value mapping.
=cut

method get_params()
{
    # Abstract: this base implementation always dies.
    confess("NotImplemented");
}

=head2 init_params

    Initialize the parameters and auxiliary states.

    Parameters
    ----------
    :$initializer : Maybe[AI::MXNet::Initializer]
        Called to initialize parameters if needed.
    :$arg_params= : Maybe[HashRef[AI::MXNet::NDArray]]
        If not undef, should be a hash ref of existing arg_params.
    :$aux_params= : Maybe[HashRef[AI::MXNet::NDArray]]
        If not undef, should be a hash ref of existing aux_params.
    :$allow_missing=0 : Bool
        If true, params could contain missing values, and the initializer will be
        called to fill those missing params.
    :$force_init=0 : Bool
        If true, will force re-initialize even if already initialized.
    :$allow_extra=0 : Bool
        Whether to allow extra parameters that are not needed by symbol.
        If this is true, no error will be thrown when arg_params or aux_params
        contain extra parameters that are not needed by the executor.
=cut

method init_params(
    Maybe[AI::MXNet::Initializer]      :$initializer=AI::MXNet::Initializer->Uniform(0.01),
    Maybe[HashRef[AI::MXNet::NDArray]] :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] :$aux_params=,
    Bool                               :$allow_missing=0,
    Bool                               :$force_init=0,
    Bool                               :$allow_extra=0
)
{ confess("NotImplemented") }   # abstract: this base implementation always dies

=head2 set_params

    Assign parameter and aux state values.

    Parameters
    ----------
    $arg_params= : Maybe[HashRef[AI::MXNet::NDArray]]
        Hash ref of name to value (NDArray) mapping.
    $aux_params= : Maybe[HashRef[AI::MXNet::NDArray]]
        Hash ref of name to value (NDArray) mapping.
    :$allow_missing=0 : Bool
        If true, params could contain missing values, and the initializer will be
        called to fill those missing params.
    :$force_init=0 : Bool
        If true, will force re-initialize even if already initialized.
    :$allow_extra=0 : Bool
        Whether to allow extra parameters that are not needed by symbol.
        If this is true, no error will be thrown when arg_params or aux_params
        contain extra parameters that are not needed by the executor.
=cut

method set_params(
    Maybe[HashRef[AI::MXNet::NDArray]]  $arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]]  $aux_params=,
    Bool                               :$allow_missing=0,
    Bool                               :$force_init=0,
    Bool                               :$allow_extra=0
)
{
    # Delegate to init_params with the initializer explicitly set to undef,
    # forwarding the given hashes and every flag unchanged.
    my %init_args = (
        initializer   => undef,
        arg_params    => $arg_params,
        aux_params    => $aux_params,
        allow_missing => $allow_missing,
        force_init    => $force_init,
        allow_extra   => $allow_extra,
    );
    return $self->init_params(%init_args);
}

=head2 save_params

    Save model parameters to file.

    Parameters
    ----------
    $fname : Str
        Path to output param file.
    $arg_params= : Maybe[HashRef[AI::MXNet::NDArray]]
    $aux_params= : Maybe[HashRef[AI::MXNet::NDArray]]
=cut

# Save arg and aux parameters to $fname via AI::MXNet::NDArray->save.
# Keys are stored as "arg:<name>" and "aux:<name>", each value copied to the
# CPU context. A hash the caller omits is taken from $self->get_params; a hash
# the caller supplies is always used as given.
method save_params(
    Str $fname,
    Maybe[HashRef[AI::MXNet::NDArray]] $arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] $aux_params=
)
{
    if (not defined $arg_params or not defined $aux_params)
    {
        my ($module_args, $module_aux) = $self->get_params;
        $arg_params //= $module_args;
        $aux_params //= $module_aux;
    }
    my $cpu = AI::MXNet::Context->cpu;
    my %save_dict;
    # keys() resets the hash iterator, so every entry is visited even if
    # other code left an each() iteration over the same hash unfinished.
    for my $name (keys %{ $arg_params })
    {
        $save_dict{"arg:$name"} = $arg_params->{$name}->as_in_context($cpu);
    }
    for my $name (keys %{ $aux_params })
    {
        $save_dict{"aux:$name"} = $aux_params->{$name}->as_in_context($cpu);
    }
    AI::MXNet::NDArray->save($fname, \%save_dict);
}

=head2 load_params

    Load model parameters from file.

    Parameters
    ----------
    $fname : str



( run in 1.896 second using v1.01-cache-2.11-cpan-39bf76dae61 )