AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module.pm  view on Meta::CPAN

    }, $param_arrays);
}

# Pushes each parameter's gradients to the kvstore and pulls the resulting
# values back into the parameter arrays.
#
#   $param_arrays : per-parameter NDArray (or array ref of NDArrays) to pull into
#   $grad_arrays  : per-parameter NDArray (or array ref of NDArrays) to push
#   $kvstore      : the AI::MXNet::KVStore used for push/pull
#   $param_names  : kvstore keys, indexed in step with the arrays above
func _update_params_on_kvstore(
    ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
    ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
    AI::MXNet::KVStore           $kvstore,
    ArrayRef[Str]                $param_names
)
{
    my $push_then_pull = sub {
        my ($idx, $weights, $grads) = @_;
        # a plain array ref whose first slot is undefined carries no gradient
        return if ref $grads eq 'ARRAY' and not defined $grads->[0];
        my $key      = $param_names->[$idx];
        # priority is the negative index
        my $priority = -$idx;
        # push gradient
        $kvstore->push($key, $grads, priority => $priority);
        # pull back the weights
        $kvstore->pull($key, out => $weights, priority => $priority);
    };
    enumerate($push_then_pull, $param_arrays, $grad_arrays);
}

# Applies the updater to every (weight, gradient) pair of every parameter.
# When a kvstore is supplied, the gradients are first pushed to it and the
# summed gradients pulled back into the same gradient arrays.
#
#   $param_arrays : per-parameter array refs of weights
#   $grad_arrays  : per-parameter array refs of gradients
#   $updater      : AI::MXNet::Updater invoked as (index, grad, weight)
#   $num_device   : multiplier used to build the index handed to the updater
#   $kvstore      : optional kvstore used to aggregate gradients
#   $param_names  : kvstore keys (read only when $kvstore is given)
func _update_params(
    ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
    ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
    AI::MXNet::Updater                     $updater,
    Int                                    $num_device,
    Maybe[AI::MXNet::KVStore]              $kvstore=,
    Maybe[ArrayRef[Str]]                   $param_names=
)
{
    my $update_one_param = sub {
        my ($idx, $weights, $grads) = @_;
        # nothing to do for a parameter without a gradient
        return unless defined $grads->[0];
        if($kvstore)
        {
            my $key      = $param_names->[$idx];
            # priority is the negative index
            my $priority = -$idx;
            # push gradient
            $kvstore->push($key, $grads, priority => $priority);
            # pull back the sum gradients, to the same locations.
            $kvstore->pull($key, out => $grads, priority => $priority);
        }
        # faked an index here, to make optimizer create diff
        # state for the same index but on diff devs, TODO(mli)
        # use a better solution later
        my $index_base = $idx * $num_device;
        enumerate(sub {
            my ($dev, $weight, $grad) = @_;
            $updater->($index_base + $dev, $grad, $weight);
        }, $weights, $grads);
    };
    enumerate($update_one_param, $param_arrays, $grad_arrays);
}

# Loads the symbol from "$prefix-symbol.json" and the saved arrays from
# "$prefix-NNNN.params" (NNNN is the zero padded epoch), then splits the
# arrays by the "arg:" / "aux:" prefix of their keys.
# Returns ($symbol, \%arg_params, \%aux_params).
method load_checkpoint(Str $prefix, Int $epoch)
{
    my $symbol = AI::MXNet::Symbol->load("$prefix-symbol.json");
    my $saved  = AI::MXNet::NDArray->load(sprintf('%s-%04d.params', $prefix, $epoch));
    my (%arg_params, %aux_params);
    for my $key (keys %{ $saved })
    {
        # keys look like "<type>:<name>"; only the first colon separates
        my ($type, $name) = split(/:/, $key, 2);
        if($type eq 'arg')
        {
            $arg_params{$name} = $saved->{$key};
        }
        elsif($type eq 'aux')
        {
            $aux_params{$name} = $saved->{$key};
        }
    }
    return ($symbol, \%arg_params, \%aux_params);
}

=head1 NAME

    AI::MXNet::Module - FeedForward interface of MXNet.
    See AI::MXNet::Module::Base for the details.
=cut

extends 'AI::MXNet::Module::Base';

# The network symbol; mandatory, passed to the constructor as 'symbol'.
has '_symbol' => (
    is       => 'ro',
    init_arg => 'symbol',
    isa      => 'AI::MXNet::Symbol',
    required => 1,
);

# Names of the data inputs; passed to the constructor as 'data_names'.
has '_data_names' => (
    is       => 'ro',
    init_arg => 'data_names',
    isa      => 'ArrayRef[Str]',
);

# Names of the label inputs (may be undef); passed as 'label_names'.
has '_label_names' => (
    is       => 'ro',
    init_arg => 'label_names',
    isa      => 'Maybe[ArrayRef[Str]]',
);

has 'work_load_list' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[Int]]',
);

has 'fixed_param_names' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[Str]]',
);

has 'state_names' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[Str]]',
);

# Defaults to the logger handed out by AI::MXNet::Logging.
has 'logger' => (
    is      => 'ro',
    default => sub { AI::MXNet::Logging->get_logger },
);

# Private state holder; cannot be supplied through the constructor.
has '_p' => (
    is       => 'rw',
    init_arg => undef,
);

# A single context or an array ref of contexts; defaults to the cpu context.
has 'context' => (
    is      => 'ro',
    isa     => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    default => sub { AI::MXNet::Context->cpu },
);

# Lets the symbol be given as a leading positional argument:
# an odd number of constructor arguments is read as ($symbol, key => value, ...).
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    # odd count: label the first argument as the 'symbol' named argument
    unshift(@args, 'symbol') if @args % 2;
    return $class->$orig(@args);
};

sub BUILD
{
    my $self = shift;
    $self->_p(AI::MXNet::Module::Private->new);
    my $context = $self->context;
    if(blessed $context)
    {
        $context = [$context];
    }
    $self->_p->_context($context);

lib/AI/MXNet/Module.pm  view on Meta::CPAN

    $self->_p->_state_names(\@state_names);
    $self->_p->_aux_names($self->_symbol->list_auxiliary_states);
    $self->_p->_data_names(\@data_names);
    $self->_p->_label_names(\@label_names);
    $self->_p->_output_names($self->_symbol->list_outputs);
    $self->_p->_params_dirty(0);
    $self->_check_input_names($self->_symbol, $self->_p->_data_names, "data", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_label_names, "label", 0);
    $self->_check_input_names($self->_symbol, $self->_p->_state_names, "state", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_fixed_param_names, "fixed_param", 1);
}

# With arguments: constructs a new object of this package.
# Without arguments: returns the package name itself.
method Module(@args)
{
    return __PACKAGE__ unless @args;
    return __PACKAGE__->new(@args);
}
# Constructs an AI::MXNet::Module::Bucketing object from the given arguments.
method BucketingModule(@args)
{
    return AI::MXNet::Module::Bucketing->new(@args);
}

=head2 load

        Create a model from a previously saved checkpoint.

        Parameters
        ----------
        prefix : str
            path prefix of saved model files. You should have
            "prefix-symbol.json", "prefix-xxxx.params", and
            optionally "prefix-xxxx.states", where xxxx is the
            epoch number.
        epoch : int
            epoch to load.
        load_optimizer_states : bool
            whether to load optimizer states. The checkpoint needs
            to have been saved with save_optimizer_states enabled.
        data_names : array ref of str
            Default is ['data'] for a typical model used in image classification.
        label_names : array ref of str
            Default is ['softmax_label'] for a typical model used in image
            classification.
        logger : Logger
            Default is AI::MXNet::Logging.
        context : Context or list of Context
            Default is cpu(0).
        work_load_list : array ref of number
            Default is undef, indicating a uniform workload.
        fixed_param_names: array ref of str
            Default is undef, indicating no network parameters are fixed.
=cut

method load(
    Str $prefix,
    Int $epoch,
    Bool $load_optimizer_states=0,
    %kwargs
)
{
    # read the symbol and the arg/aux parameters stored in the checkpoint
    my ($symbol, $arg_params, $aux_params) = __PACKAGE__->load_checkpoint($prefix, $epoch);
    my $module  = $self->new(symbol => $symbol, %kwargs);
    my $private = $module->_p;
    $private->_arg_params($arg_params);
    $private->_aux_params($aux_params);
    $module->params_initialized(1);
    if($load_optimizer_states)
    {
        # only the states file name is recorded here
        $private->_preload_opt_states(sprintf('%s-%04d.states', $prefix, $epoch));
    }
    return $module;
}

=head2 save_checkpoint

    Save current progress to a checkpoint.
    Use mx->callback->module_checkpoint as epoch_end_callback to save during training.

    Parameters
    ----------
    prefix : str
        The file prefix to checkpoint to
    epoch : int
        The current epoch number
    save_optimizer_states : bool
        Whether to save optimizer states for later training
=cut


method save_checkpoint(Str $prefix, Int $epoch, Bool $save_optimizer_states=0)
{
    # the symbol goes to "<prefix>-symbol.json"
    $self->_symbol->save("$prefix-symbol.json");

    # the parameters go to "<prefix>-<zero padded epoch>.params"
    my $params_file = sprintf('%s-%04d.params', $prefix, $epoch);
    $self->save_params($params_file);
    AI::MXNet::Logging->info('Saved checkpoint to "%s"', $params_file);

    if($save_optimizer_states)
    {
        # the optimizer states go to "<prefix>-<zero padded epoch>.states"
        my $states_file = sprintf('%s-%04d.states', $prefix, $epoch);
        $self->save_optimizer_states($states_file);
        AI::MXNet::Logging->info('Saved optimizer state to "%s"', $states_file);
    }
}

=head2 model_save_checkpoint

    Checkpoint the model data into file.

    Parameters
    ----------
    prefix : str
        Prefix of model name.
    epoch : int
        The epoch number of the model.
    symbol : AI::MXNet::Symbol
        The input symbol
    arg_params : hash ref of str to AI::MXNet::NDArray
        Model parameter, hash ref of name to AI::MXNet::NDArray of net's weights.
    aux_params : hash ref of str to AI::MXNet::NDArray
        Model parameter, hash ref of name to AI::MXNet::NDArray of net's auxiliary states.
    Notes
    -----
    - prefix-symbol.json will be saved for symbol.
    - prefix-epoch.params will be saved for parameters.
=cut

method model_save_checkpoint(
    Str                         $prefix,
    Int                         $epoch,
    Maybe[AI::MXNet::Symbol]    $symbol,
    HashRef[AI::MXNet::NDArray] $arg_params,
    HashRef[AI::MXNet::NDArray] $aux_params
)
{
    # the symbol is optional; it is written only when one was given
    $symbol->save("$prefix-symbol.json") if defined $symbol;

    # the parameters go to "<prefix>-<zero padded epoch>.params"
    my $params_file = sprintf('%s-%04d.params', $prefix, $epoch);
    $self->save_params($params_file, $arg_params, $aux_params);
    AI::MXNet::Logging->info('Saved checkpoint to "%s"', $params_file);
}

# Internal helper: clears the binded flag and drops the executor group
# together with the stored data and label shapes.
method _reset_bind()
{
    $self->binded(0);
    my $private = $self->_p;
    $private->_exec_group(undef);
    $private->_data_shapes(undef);
    $private->_label_shapes(undef);
}

# Returns the data names held by the private state object.
method data_names()
{
    my $private = $self->_p;
    return $private->_data_names;
}

# Returns the label names held by the private state object.
method label_names()
{
    my $private = $self->_p;
    return $private->_label_names;
}

# Returns the output names held by the private state object.
method output_names()
{
    my $private = $self->_p;
    return $private->_output_names;
}

# Returns the stored data shapes; asserts that the module is binded.
method data_shapes()
{
    assert($self->binded);
    my $private = $self->_p;
    return $private->_data_shapes;
}

# Returns the stored label shapes; asserts that the module is binded.
method label_shapes()
{
    assert($self->binded);
    my $private = $self->_p;
    return $private->_label_shapes;
}

# Returns the output shapes reported by the executor group;
# asserts that the module is binded.
method output_shapes()
{
    assert($self->binded);
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->get_output_shapes;
}

# Returns ($arg_params, $aux_params) from the private state object.
# Asserts that the module is binded and its parameters are initialized;
# when the parameters are flagged dirty they are synced from the devices first.
method get_params()
{
    assert($self->binded && $self->params_initialized);
    $self->_sync_params_from_devices() if $self->_p->_params_dirty;
    my $private = $self->_p;
    return ($private->_arg_params, $private->_aux_params);
}

method init_params(
    Maybe[AI::MXNet::Initializer]      :$initializer=AI::MXNet::Initializer->Uniform(scale => 0.01),
    Maybe[HashRef[AI::MXNet::NDArray]] :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] :$aux_params=,

lib/AI/MXNet/Module.pm  view on Meta::CPAN

            $self->_p->_exec_group->param_names
        );
    }
}

# Delegates to the executor group's get_outputs.
# Asserts that the module is binded and its parameters are initialized.
method get_outputs(Bool $merge_multi_context=1)
{
    assert($self->binded && $self->params_initialized);
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->get_outputs($merge_multi_context);
}

# Delegates to the executor group's get_input_grads.
# Asserts that the module is binded, its parameters are initialized
# and inputs_need_grad is set.
method get_input_grads(Bool $merge_multi_context=1)
{
    assert($self->binded && $self->params_initialized && $self->inputs_need_grad);
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->get_input_grads($merge_multi_context);
}

# Delegates to the executor group's get_states.
# Asserts that the module is binded and its parameters are initialized.
method get_states(Bool $merge_multi_context=1)
{
    assert($self->binded && $self->params_initialized);
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->get_states($merge_multi_context);
}

# Forwards the named arguments 'states' and 'value' positionally to the
# executor group's set_states.
# Asserts that the module is binded and its parameters are initialized.
method set_states(:$states=, :$value=)
{
    assert($self->binded && $self->params_initialized);
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->set_states($states, $value);
}

# Hands the evaluation metric and the labels over to the executor group.
method update_metric(
    AI::MXNet::EvalMetric $eval_metric,
    ArrayRef[AI::MXNet::NDArray] $labels
)
{
    my $exec_group = $self->_p->_exec_group;
    return $exec_group->update_metric($eval_metric, $labels);
}

=head2 _sync_params_from_devices

    Synchronize parameters from devices to CPU. This function should be called after
    calling 'update' that updates the parameters on the devices, before one can read the
    latest parameters from $self->_p->_arg_params and $self->_p->_aux_params.
=cut

method _sync_params_from_devices()
{
    my $private = $self->_p;
    # the executor group is handed the stored arg/aux parameter hashes
    $private->_exec_group->get_params($private->_arg_params, $private->_aux_params);
    # the stored parameters are no longer considered dirty
    $private->_params_dirty(0);
}

# Saves the optimizer states to the file $fname.
# Asserts that the optimizer is initialized. When updates run on the kvstore
# the kvstore writes the file; otherwise the updater's states are written
# here in raw mode. Confesses when the file cannot be opened, written or closed.
method save_optimizer_states(Str $fname)
{
    assert($self->optimizer_initialized);
    if($self->_p->_update_on_kvstore)
    {
        $self->_p->_kvstore->save_optimizer_states($fname);
    }
    else
    {
        # lexical handle: no package-global bareword handle to clash with
        open(my $fh, ">:raw", $fname) or confess("can't open $fname for writing: $!");
        print {$fh} $self->_p->_updater->get_states()
            or confess("can't write optimizer states to $fname: $!");
        # buffered write errors only surface at close, so it is checked as well
        close($fh) or confess("can't close $fname after writing: $!");
    }
}

# Loads the optimizer states from the file $fname.
# Asserts that the optimizer is initialized. When updates run on the kvstore
# the kvstore reads the file; otherwise the whole file is read in raw mode
# and handed to the updater. Confesses when the file cannot be opened.
method load_optimizer_states(Str $fname)
{
    assert($self->optimizer_initialized);
    if($self->_p->_update_on_kvstore)
    {
        $self->_p->_kvstore->load_optimizer_states($fname);
    }
    else
    {
        # lexical handle: no package-global bareword handle to clash with
        open(my $fh, "<:raw", $fname) or confess("can't open $fname for reading: $!");
        # slurp the whole file; the record separator is undefined only inside the do block
        my $data = do { local $/; <$fh> };
        close($fh);
        $self->_p->_updater->set_states($data);
    }
}

# Installs the monitor on the executor group; asserts that the module is binded.
method install_monitor(AI::MXNet::Monitor $mon)
{
    assert($self->binded);
    my $exec_group = $self->_p->_exec_group;
    $exec_group->install_monitor($mon);
}

# Returns the updater held by the private state object.
method _updater()
{
    return $self->_p->_updater;
}

# Returns the kvstore held by the private state object.
method _kvstore()
{
    return $self->_p->_kvstore;
}

1;



( run in 0.859 second using v1.01-cache-2.11-cpan-39bf76dae61 )