AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

package AI::MXNet::BatchEndParam;
use Mouse;
use AI::MXNet::Function::Parameters;

# Mutable record holding the current epoch, the batch counter within that
# epoch, and the evaluation metric object (presumably handed to batch-end
# callbacks -- confirm against callers).
has 'epoch'       => (is => 'rw', isa => 'Int');
has 'nbatch'      => (is => 'rw', isa => 'Int');
has 'eval_metric' => (is => 'rw', isa => 'AI::MXNet::EvalMetric');

package AI::MXNet::Module::Base;
use Mouse;
use AI::MXNet::Base;
use Time::HiRes qw(time);

=head1 NAME

    AI::MXNet::Module::Base - Base class for AI::MXNet::Module and AI::MXNet::Module::Bucketing
=cut

# Normalize a value to an array ref: a plain (unblessed) array ref is
# returned as-is; anything else -- scalars, undef, hash refs, and blessed
# objects, even array-based ones -- is wrapped in a one-element array ref.
func _as_list($obj)
{
    my $kind = ref($obj) // '';
    return $kind eq 'ARRAY' ? $obj : [$obj];
}

# Check that all input names are in symbol's argument
# Verify that every requested input name is an argument of $symbol.
#
# $symbol   - symbol whose list_arguments() defines the valid names
# $names    - input names the module was created with
# $typename - prefix used in the message (e.g. 'data' -> "data_names=")
# $throw    - true: confess on the first unknown name; false: log a warning
#             for each unknown name and keep going
#
# 'softmax_label' is never reported, even when the symbol lacks it.
method _check_input_names(
    AI::MXNet::Symbol $symbol,
    ArrayRef[Str]     $names,
    Str               $typename,
    Bool              $throw
)
{
    my @arguments   = @{ $symbol->list_arguments };
    my %is_argument = map { ($_ => 1) } @arguments;

    # Suggestions exclude arguments that look like learned parameters
    # (names ending in _weight, _bias, _gamma or _beta).
    my @candidates = grep { !/_(?:weight|bias|gamma|beta)$/ } @arguments;

    my @unknown = grep {
        (not exists $is_argument{$_}) and $_ ne 'softmax_label'
    } @$names;

    for my $missing (@unknown)
    {
        my $msg = sprintf("\033[91mYou created Module with Module(..., %s_names=%s) but "
            ."input with name '%s' is not found in symbol.list_arguments(). "
            ."Did you mean one of:\n\t%s\033[0m",
            $typename, "@$names", $missing, join("\n\t", @candidates)
        );
        if (!$throw)
        {
            AI::MXNet::Logging->warning($msg);
            next;
        }
        confess($msg);
    }
}

# Check that input names matches input data descriptors
# Verify that the names carried by $data_shapes (element 0 of each entry)
# equal $data_names, in the same order.
#
# $data_names  - expected input names
# $data_shapes - name/shape pairs or AI::MXNet::DataDesc objects
# $name        - prefix used in the message (e.g. 'data' or 'label')
# $throw       - true: confess on mismatch; false: log a warning instead
method _check_names_match(
    ArrayRef[Str]                  $data_names,
    ArrayRef[NameShapeOrDataDesc]  $data_shapes,
    Str                            $name,
    Bool                           $throw
)
{
    # A lone 'softmax_label' name with no shapes at all is accepted as-is.
    my $only_default_label = !@$data_shapes
        && @$data_names == 1
        && $data_names->[0] eq 'softmax_label';
    return if $only_default_label;

    my @actual = map { $_->[0] } @$data_shapes;
    return if "@$data_names" eq "@actual";

    my $msg = sprintf(
        "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)",
        $name, $name, "@$data_shapes", "@$data_names"
    );
    if (!$throw)
    {
        AI::MXNet::Logging->warning($msg);
        return;
    }
    confess($msg);
}

# Normalize data/label shape specifications to AI::MXNet::DataDesc objects
# and check them against the declared input names.
#
# $data_names   - declared data input names; a mismatch confesses
# $label_names  - declared label input names, or undef for "no labels";
#                 a mismatch only logs a warning
# $data_shapes  - name/shape pairs or DataDesc objects for the data inputs
# $label_shapes - same for the labels, or undef
#
# Returns ($data_shapes, $label_shapes) with every unblessed entry converted
# via AI::MXNet::DataDesc->new(@$entry); $label_shapes stays undef if it was.
method _parse_data_desc(
    ArrayRef[Str]                                  $data_names,
    Maybe[ArrayRef[Str]]                           $label_names,
    ArrayRef[NameShapeOrDataDesc]                  $data_shapes,
    Maybe[ArrayRef[NameShapeOrDataDesc]]           $label_shapes
)
{
    # The signature accepts undef label names, but _check_names_match
    # requires an ArrayRef[Str]; passing undef through would die on the
    # type check instead of validating. Treat undef as "no label names".
    my $expected_labels = $label_names // [];

    $data_shapes = [map { blessed($_) ? $_ : AI::MXNet::DataDesc->new(@$_) } @$data_shapes];
    $self->_check_names_match($data_names, $data_shapes, 'data', 1);
    if ($label_shapes)
    {
        $label_shapes = [map { blessed($_) ? $_ : AI::MXNet::DataDesc->new(@$_) } @$label_shapes];
        $self->_check_names_match($expected_labels, $label_shapes, 'label', 0);
    }
    else
    {
        $self->_check_names_match($expected_labels, [], 'label', 0);
    }
    return ($data_shapes, $label_shapes);
}

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

{
    $self->init_params(
        initializer   => undef,
        arg_params    => $arg_params,
        aux_params    => $aux_params,
        allow_missing => $allow_missing,
        force_init    => $force_init,
        allow_extra   => $allow_extra
    );
}

=head2 save_params

    Save model parameters to file.

    Parameters
    ----------
    $fname : str
        Path to output param file.
    $arg_params= : Maybe[HashRef[AI::MXNet::NDArray]]
    $aux_params= : Maybe[HashRef[AI::MXNet::NDArray]]
=cut

# Save model parameters to $fname.
#
# Entries are copied to the CPU context and stored under "arg:<name>" and
# "aux:<name>" keys. If $arg_params or $aux_params is omitted (undef), only
# the omitted one is taken from $self->get_params; an explicitly supplied
# hash is always used as given.
method save_params(
    Str $fname,
    Maybe[HashRef[AI::MXNet::NDArray]] $arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] $aux_params=
)
{
    if (not defined $arg_params or not defined $aux_params)
    {
        # Fill in only what is missing. Previously both were overwritten
        # whenever either was absent, silently discarding the caller's hash.
        my ($current_arg, $current_aux) = $self->get_params;
        $arg_params //= $current_arg;
        $aux_params //= $current_aux;
    }

    my $cpu = AI::MXNet::Context->cpu;
    my %save_dict;
    # Use keys() rather than each(): each() continues from the hash's shared
    # iterator, so a caller-owned hash whose iterator was partially consumed
    # would silently lose entries.
    for my $k (keys %{ $arg_params })
    {
        $save_dict{"arg:$k"} = $arg_params->{$k}->as_in_context($cpu);
    }
    for my $k (keys %{ $aux_params })
    {
        $save_dict{"aux:$k"} = $aux_params->{$k}->as_in_context($cpu);
    }
    AI::MXNet::NDArray->save($fname, \%save_dict);
}

=head2 load_params

    Load model parameters from file.

    Parameters
    ----------
    $fname : str
        Path to input param file.
=cut

# Load model parameters from $fname and install them with set_params.
#
# Every key in the file must look like "arg:<name>" or "aux:<name>" (split on
# the first ':'); any other key, or a file whose contents do not load as a
# hash ref, confesses "Invalid param file $fname".
method load_params(Str $fname)
{
    my $loaded = AI::MXNet::NDArray->load($fname);
    # Only a hash ref can carry the "arg:"/"aux:" keys written by save_params.
    # NOTE(review): presumably NDArray->load returns an array ref for files
    # saved without names -- confirm. Report that as an invalid param file
    # instead of dying with "Not a HASH reference".
    confess("Invalid param file $fname") unless ref($loaded) eq 'HASH';

    my %arg_params;
    my %aux_params;
    for my $k (keys %$loaded)
    {
        my ($arg_type, $name) = split(/:/, $k, 2);
        if ($arg_type eq 'arg')
        {
            $arg_params{ $name } = $loaded->{$k};
        }
        elsif ($arg_type eq 'aux')
        {
            $aux_params{ $name } = $loaded->{$k};
        }
        else
        {
            confess("Invalid param file $fname");
        }
    }
    $self->set_params(\%arg_params, \%aux_params);
}

=head2 get_states

    The states from all devices

    Parameters
    ----------
    $merge_multi_context=1 : Bool
        Default is true (1). When data parallelism is used, the states
        are collected from multiple devices. A true value indicates that the
        collected results should be merged so that they look as if they came
        from a single executor. (The base-class implementation has no states
        and asserts that this flag is false.)

    Returns
    -------
    If $merge_multi_context is 1, it is like [$out1, $out2]. Otherwise, it
    is like [[$out1_dev1, $out1_dev2], [$out2_dev1, $out2_dev2]]. All the output
    elements are AI::MXNet::NDArray.
=cut

# Base implementation: there are no states to report.
# Requires the module to be bound with parameters initialized, and asserts
# that $merge_multi_context is false -- so calling it with the default (1)
# fails the assertion. Returns an empty array ref.
method get_states(Bool $merge_multi_context=1)
{
    assert($self->binded and $self->params_initialized);
    assert(!$merge_multi_context);
    my @no_states;
    return \@no_states;
}

=head2 set_states

    Set value for states. You can specify either $states or $value, not both.

    Parameters
    ----------
    $states= : Maybe[ArrayRef[ArrayRef[AI::MXNet::NDArray]]]
        source states arrays formatted like [[$state1_dev1, $state1_dev2],
            [$state2_dev1, $state2_dev2]].
    $value= : Maybe[Num]
        a single scalar value for all state arrays.
=cut

# Base implementation: there are no states to set.
# Requires the module to be bound with parameters initialized, and asserts
# that neither $states nor $value is supplied (both must be false/undef).
method set_states(Maybe[ArrayRef[ArrayRef[AI::MXNet::NDArray]]] $states=, Maybe[Num] $value=)
{
    assert($self->binded and $self->params_initialized);
    assert(!($states or $value));
}


=head2 install_monitor

    Install monitor on all executors

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

    :$inputs_need_grad=0 : Bool
        Default is 0. Whether the gradients to the input data need to be computed.
        Typically this is not needed. But this might be needed when implementing composition
        of modules.
    :$force_rebind=0 : Bool
        Default is 0. This function does nothing if the executors are already
        bound. But with this set to 1, the executors will be forced to rebind.
    :$shared_module= : A subclass of AI::MXNet::Module::Base
        Default is undef. This is used in bucketing. When not undef, the shared module
        essentially corresponds to a different bucket -- a module with different symbol
        but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
    :$grad_req='write' : Str|ArrayRef[Str]|HashRef[Str]
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (str) or for each argument (array ref, hash ref).
=cut

# Abstract: concrete module classes must override bind.
# The signature documents the accepted arguments; this stub always confesses.
method bind(
    ArrayRef[AI::MXNet::DataDesc]         $data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc]] :$label_shapes=,
    Bool                                 :$for_training=1,
    Bool                                 :$inputs_need_grad=0,
    Bool                                 :$force_rebind=0,
    Maybe[AI::MXNet::BaseModule]         :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]       :$grad_req='write'
)
{
    confess('NotImplemented');
}

=head2 init_optimizer

    Install and initialize optimizers.

    Parameters
    ----------
    :$kvstore='local' : str or KVStore
    :$optimizer='sgd' : str or Optimizer
    :$optimizer_params={ learning_rate => 0.01 } : hash ref
    :$force_init=0 : Bool
=cut

# Abstract: concrete module classes must override init_optimizer.
# NOTE(review): the POD allows a KVStore object for $kvstore, but this
# stub's signature only accepts Str -- check subclasses for the real contract.
method init_optimizer(
    Str        :$kvstore='local',
    Optimizer  :$optimizer='sgd',
    HashRef    :$optimizer_params={ learning_rate => 0.01 },
    Bool       :$force_init=0
)
{
    confess('NotImplemented');
}

################################################################################
# misc
################################################################################

=head2 symbol

    The symbol associated with this module.

    Except for AI::MXNet::Module, for other types of modules (e.g. AI::MXNet::Module::Bucketing), this
    property might not be constant throughout the module's lifetime. Some modules might
    not even be associated with any symbol.
=cut

# Accessor for the symbol held in the _symbol attribute.
method symbol()
{
    my $sym = $self->_symbol;
    return $sym;
}

1;



( run in 1.623 second using v1.01-cache-2.11-cpan-df04353d9ac )