AI-MXNet

lib/AI/MXNet/Module.pm

{
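    # aggregate each parameter's gradients across devices (via the kvstore
    # when one is in use), then apply the updater to every device's copy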
    enumerate(sub{
        my ($index, $arg_list, $grad_list) = @_;
        if(not defined $grad_list->[0])
        {
            return;
        }
        if($kvstore)
        {
            my $name = $param_names->[$index];
            # push gradient, priority is negative index
            $kvstore->push($name, $grad_list, priority => -$index);
            # pull back the sum gradients, to the same locations.
            $kvstore->pull($name, out => $grad_list, priority => -$index);
        }
        enumerate(sub {
            my ($k, $w, $g) = @_;
            # fake an index here so the optimizer creates a separate
            # state for the same parameter on different devices
            # TODO(mli): use a better solution later
            &{$updater}($index*$num_device+$k, $g, $w);
        }, $arg_list, $grad_list);
    }, $param_arrays, $grad_arrays);
}

method load_checkpoint(Str $prefix, Int $epoch)
{
    my $symbol = AI::MXNet::Symbol->load("$prefix-symbol.json");
    my %save_dict = %{ AI::MXNet::NDArray->load(sprintf('%s-%04d.params', $prefix, $epoch)) };
    my %arg_params;
    my %aux_params;
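    # saved keys look like 'arg:fc1_weight' or 'aux:bn_moving_mean';
    # split off the prefix and sort each NDArray into the right hash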
    while(my ($k, $v) = each %save_dict)
    {
        my ($tp, $name) = split(/:/, $k, 2);
        if($tp eq 'arg')
        {
            $arg_params{$name} = $v;
        }
        elsif($tp eq 'aux')
        {
            $aux_params{$name} = $v;
        }
    }
    return ($symbol, \%arg_params, \%aux_params);
}

=head1 NAME

    AI::MXNet::Module - FeedForward interface of MXNet.
    See AI::MXNet::Module::Base for the details.
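
    A typical lifecycle, as a minimal sketch ($net and $train_iter are
    illustrative placeholders):

        use AI::MXNet qw(mx);
        my $mod = mx->mod->Module(symbol => $net);
        $mod->bind(
            data_shapes  => $train_iter->provide_data,
            label_shapes => $train_iter->provide_label
        );
        $mod->init_params;
        $mod->init_optimizer(optimizer => 'sgd');
        # or let fit() handle the bind/init steps internally:
        # $mod->fit($train_iter, num_epoch => 10);
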
=cut

extends 'AI::MXNet::Module::Base';

has '_symbol'           => (is => 'ro', init_arg => 'symbol', isa => 'AI::MXNet::Symbol', required => 1);
has '_data_names'       => (is => 'ro', init_arg => 'data_names', isa => 'ArrayRef[Str]');
has '_label_names'      => (is => 'ro', init_arg => 'label_names', isa => 'Maybe[ArrayRef[Str]]');
has 'work_load_list'    => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
has 'fixed_param_names' => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'state_names'       => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'logger'            => (is => 'ro', default => sub { AI::MXNet::Logging->get_logger });
has '_p'                => (is => 'rw', init_arg => undef);
has 'context'           => (
    is => 'ro', 
    isa => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    default => sub { AI::MXNet::Context->cpu }
);

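# accept a positional symbol, i.e. AI::MXNet::Module->new($symbol, %kwargs)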
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    if(@_%2)
    {
        my $symbol = shift;
        return $class->$orig(symbol => $symbol, @_);
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    $self->_p(AI::MXNet::Module::Private->new);
    my $context = $self->context;
    if(blessed $context)
    {
        $context = [$context];
    }
    $self->_p->_context($context);
    my $work_load_list = $self->work_load_list;
    if(not defined $work_load_list)
    {
        $work_load_list = [(1)x@{$self->_p->_context}];
    }
    assert(@{ $work_load_list } == @{ $self->_p->_context });
    $self->_p->_work_load_list($work_load_list);
    my @data_names  = @{ $self->_data_names//['data'] };
    my @label_names = @{ $self->_label_names//['softmax_label'] };
    my @state_names = @{ $self->state_names//[] };
    my $arg_names   = $self->_symbol->list_arguments;
    my @input_names = (@data_names, @label_names, @state_names);
    my %input_names = map { $_ => 1 } @input_names;
    $self->_p->_param_names([grep { not exists $input_names{$_} } @{ $arg_names }]);
    $self->_p->_fixed_param_names($self->fixed_param_names//[]);
    $self->_p->_state_names(\@state_names);
    $self->_p->_aux_names($self->_symbol->list_auxiliary_states);
    $self->_p->_data_names(\@data_names);
    $self->_p->_label_names(\@label_names);
    $self->_p->_output_names($self->_symbol->list_outputs);
    $self->_p->_params_dirty(0);
    $self->_check_input_names($self->_symbol, $self->_p->_data_names, "data", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_label_names, "label", 0);
    $self->_check_input_names($self->_symbol, $self->_p->_state_names, "state", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_fixed_param_names, "fixed_param", 1);
}

method Module(@args) { return @args ? __PACKAGE__->new(@args) : __PACKAGE__ }
method BucketingModule(@args) { return AI::MXNet::Module::Bucketing->new(@args) }

=head2 load

        Create a model from a previously saved checkpoint.

        Parameters
        ----------
        prefix : str
            path prefix of saved model files. You should have
            "prefix-symbol.json", "prefix-xxxx.params", and
            optionally "prefix-xxxx.states", where xxxx is the
            epoch number.
        epoch : int
            epoch to load.
        load_optimizer_states : bool
            Whether to load optimizer states. The checkpoint needs
            to have been made with save_optimizer_states => 1.
        data_names : array ref of str
            Default is ['data'] for a typical model used in image classification.
        label_names : array ref of str
            Default is ['softmax_label'] for a typical model used in image
            classification.
        logger : Logger
            Default is AI::MXNet::Logging.
        context : Context or list of Context
            Default is cpu(0).
        work_load_list : array ref of number
            Default is undef, indicating a uniform workload.
        fixed_param_names: array ref of str
            Default is undef, indicating no network parameters are fixed.
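
        A minimal usage sketch (the prefix and epoch are illustrative, and
        assume 'mymodel-symbol.json' and 'mymodel-0010.params' exist):

            my $mod = AI::MXNet::Module->load(
                'mymodel', 10, 1,
                context => AI::MXNet::Context->gpu(0)
            );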
=cut

method load(
    Str $prefix,
    Int $epoch,
    Bool $load_optimizer_states=0,
    %kwargs
)
{
    my ($sym, $args, $auxs) = __PACKAGE__->load_checkpoint($prefix, $epoch);
    my $mod = $self->new(symbol => $sym, %kwargs);
    $mod->_p->_arg_params($args);
    $mod->_p->_aux_params($auxs);
    $mod->params_initialized(1);
    if($load_optimizer_states)
    {
        $mod->_p->_preload_opt_states(sprintf('%s-%04d.states', $prefix, $epoch));
    }
    return $mod;
}

=head2 save_checkpoint

    Save current progress to a checkpoint.
    Use mx->callback->module_checkpoint as epoch_end_callback to save during training.

    Parameters
    ----------
    prefix : str
        The file prefix to checkpoint to.
    epoch : int
        The current epoch number.
    save_optimizer_states : bool
        Whether to save optimizer states for later training.
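
    A minimal usage sketch (the prefix and epoch are illustrative):

        $mod->save_checkpoint('mymodel', 10, 1);
        # writes mymodel-symbol.json, mymodel-0010.params,
        # and mymodel-0010.states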
=cut


method save_checkpoint(Str $prefix, Int $epoch, Bool $save_optimizer_states=0)
{
    $self->_symbol->save("$prefix-symbol.json");
    my $param_name = sprintf('%s-%04d.params', $prefix, $epoch);
    $self->save_params($param_name);
    AI::MXNet::Logging->info('Saved checkpoint to "%s"', $param_name);
    if($save_optimizer_states)
    {
        my $state_name = sprintf('%s-%04d.states', $prefix, $epoch);
        $self->save_optimizer_states($state_name);
        AI::MXNet::Logging->info('Saved optimizer state to "%s"', $state_name);
    }
}

=head2 model_save_checkpoint


    }

    if($self->params_initialized and not $force_init)
    {
        AI::MXNet::Logging->warning(
            "Parameters already initialized and force_init=False. "
            ."set_params call ignored."
        );
        return;
    }
    $self->_p->_exec_group->set_params($arg_params, $aux_params, $allow_extra);
    $self->_p->_params_dirty(1);
    $self->params_initialized(1);
}

=head2 bind

    Bind the symbols to construct executors. This is necessary before one
    can perform computation with the module.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
        Typically is $data_iter->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically is $data_iter->provide_label.
    :$for_training : bool
        Default is 1. Whether the executors should be bound for training.
    :$inputs_need_grad : bool
        Default is 0. Whether the gradients to the input data need to be computed.
        Typically this is not needed. But this might be needed when implementing composition
        of modules.
    :$force_rebind : bool
        Default is 0. This function does nothing if the executors are already
        bound. But if this is set to 1, the executors will be forced to rebind.
    :$shared_module : Module
        Default is undef. This is used in bucketing. When not undef, the shared module
        essentially corresponds to a different bucket -- a module with different symbol
        but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
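
    A minimal usage sketch ($train_iter is an illustrative data iterator):

        $mod->bind(
            data_shapes  => $train_iter->provide_data,
            label_shapes => $train_iter->provide_label,
            for_training => 1
        );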
=cut

method bind(
    ArrayRef[AI::MXNet::DataDesc|NameShape]        :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=,
    Bool                                           :$for_training=1,
    Bool                                           :$inputs_need_grad=0,
    Bool                                           :$force_rebind=0,
    Maybe[AI::MXNet::Module]                       :$shared_module=,
    GradReq|HashRef[GradReq]|ArrayRef[GradReq]     :$grad_req='write',
    Maybe[ArrayRef[Str]]                           :$state_names=$self->_p->_state_names
)
{
    # force rebinding is typically used when one wants to switch from
    # the training phase to the prediction phase.
    if($force_rebind)
    {
        $self->_reset_bind();
    }
    if($self->binded)
    {
        $self->logger->warning('Already bound, ignoring bind()');
        return;
    }
    $self->for_training($for_training);
    $self->inputs_need_grad($inputs_need_grad);
    $self->binded(1);
    $self->_p->_grad_req($grad_req);

    if(not $for_training)
    {
        assert(not $inputs_need_grad);
    }
    ($data_shapes, $label_shapes) = $self->_parse_data_desc(
        $self->data_names, $self->label_names, $data_shapes, $label_shapes
    );
    $self->_p->_data_shapes($data_shapes);
    $self->_p->_label_shapes($label_shapes);
    my $shared_group;
    if($shared_module)
    {
        assert($shared_module->binded and $shared_module->params_initialized);
        $shared_group = $shared_module->_p->_exec_group;
    }

    $self->_p->_exec_group(
        AI::MXNet::DataParallelExecutorGroup->new(
            symbol            => $self->_symbol,
            contexts          => $self->_p->_context,
            workload          => $self->_p->_work_load_list,
            data_shapes       => $self->_p->_data_shapes,
            label_shapes      => $self->_p->_label_shapes,
            param_names       => $self->_p->_param_names,
            state_names       => $state_names,
            for_training      => $for_training,
            inputs_need_grad  => $inputs_need_grad,
            shared_group      => $shared_group,
            logger            => $self->logger,
            fixed_param_names => $self->_p->_fixed_param_names,
            grad_req          => $grad_req
        )
    );
    if($shared_module)
    {
        $self->params_initialized(1);
        $self->_p->_arg_params($shared_module->_p->_arg_params);
        $self->_p->_aux_params($shared_module->_p->_aux_params);
    }
    elsif($self->params_initialized)
    {
        # if the parameters are already initialized, we are re-binding
        # so automatically copy the already initialized params
        $self->_p->_exec_group->set_params($self->_p->_arg_params, $self->_p->_aux_params);
    }
    else
    {
        assert(not defined $self->_p->_arg_params and not defined $self->_p->_aux_params);
        my @param_arrays = (
            map { AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype) }
            @{ $self->_p->_exec_group->_p->param_arrays }
        );
        my %arg_params;
        @arg_params{ @{ $self->_p->_param_names } } = @param_arrays;
        $self->_p->_arg_params(\%arg_params);
        my @aux_arrays = (
            map { AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype) }
            @{ $self->_p->_exec_group->_p->aux_arrays }
        );
        my %aux_params;
        @aux_params{ @{ $self->_p->_aux_names } } = @aux_arrays;
        $self->_p->_aux_params(\%aux_params);
    }
    if($shared_module and $shared_module->optimizer_initialized)
    {
        $self->borrow_optimizer($shared_module);
    }
}

=head2 reshape

    Reshape the module for new input shapes.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc]
        Typically is $data_iter->provide_data.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc]]
        Typically is $data_iter->provide_label.
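
    A minimal usage sketch (the name and shape are illustrative):

        $mod->reshape(data_shapes => [['data', [128, 20]]]);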
=cut

method reshape(
    ArrayRef[AI::MXNet::DataDesc|NameShape]        :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=
)
{
    assert($self->binded);
    ($data_shapes, $label_shapes) = $self->_parse_data_desc(
        $self->data_names, $self->label_names, $data_shapes, $label_shapes
    );
    $self->_p->_data_shapes($data_shapes);
    $self->_p->_label_shapes($label_shapes);
    $self->_p->_exec_group->reshape($self->_p->_data_shapes, $self->_p->_label_shapes);
}

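=head2 init_optimizer

    Install and initialize the optimizer.

    Parameters
    ----------
    :$kvstore : str or AI::MXNet::KVStore
        Default is 'local'.
    :$optimizer : str or AI::MXNet::Optimizer
        Default is 'sgd'.
    :$optimizer_params : hash ref
        Default is { learning_rate => 0.01 }.
    :$force_init : bool
        Default is 0, indicating that an already initialized optimizer
        is left untouched.

    A minimal usage sketch (the parameter values are illustrative):

        $mod->init_optimizer(
            optimizer        => 'sgd',
            optimizer_params => { learning_rate => 0.1, momentum => 0.9 }
        );

=cut
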
method init_optimizer(
    Str|AI::MXNet::KVStore :$kvstore='local',
    Optimizer              :$optimizer='sgd',
    HashRef                :$optimizer_params={ learning_rate => 0.01 },
    Bool                   :$force_init=0
)
{
    assert($self->binded and $self->params_initialized);
    if($self->optimizer_initialized and not $force_init)
    {
        $self->logger->warning('optimizer already initialized, ignoring...');
        return;
    }
    if($self->_p->_params_dirty)
    {
        $self->_sync_params_from_devices;
    }

    my ($kvstore, $update_on_kvstore) = _create_kvstore(
        $kvstore,
        scalar(@{$self->_p->_context}),
        $self->_p->_arg_params
    );
    my $batch_size = $self->_p->_exec_group->_p->batch_size;
    if($kvstore and $kvstore->type =~ /dist/ and $kvstore->type =~ /_sync/)
    {
        $batch_size *= $kvstore->num_workers;
    }
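    # gradients are summed over the whole (global) batch, so they are
    # rescaled by 1/batch_size; for dist_sync kvstores the effective
    # batch spans all workers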
    my $rescale_grad = 1/$batch_size;

    if(not blessed $optimizer)
    {
        my %idx2name;
        if($update_on_kvstore)
        {
            @idx2name{ 0..@{$self->_p->_exec_group->param_names}-1 } = @{$self->_p->_exec_group->param_names};
        }
        else
        {
            my $num_device = @{ $self->_p->_context };
            for my $k (0..$num_device-1)
            {
                # match the updater's faked indexing: parameter $i on
                # device $k is assigned the index $i*$num_device + $k
                @idx2name{ map { $_ * $num_device + $k } 0..@{$self->_p->_exec_group->param_names}-1 } = @{$self->_p->_exec_group->param_names};
            }
        }
        if(not exists $optimizer_params->{rescale_grad})
        {
            $optimizer_params->{rescale_grad} = $rescale_grad;
        }
        $optimizer = AI::MXNet::Optimizer->create(
            $optimizer,
            sym  => $self->symbol,
            param_idx2name => \%idx2name,
            %{ $optimizer_params }
        );
    }
    else
    {
        # the optimizer object was created manually outside of the Module;
        # warn if its rescale_grad does not match the effective batch size
        if($optimizer->rescale_grad != $rescale_grad)
        {
            AI::MXNet::Logging->warning(
                "Optimizer created manually outside Module but rescale_grad "
                ."is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "
                ."Is this intended?",
                $optimizer->rescale_grad, $rescale_grad
            );
        }
    }

    $self->_p->_optimizer($optimizer);
    $self->_p->_kvstore($kvstore);
    $self->_p->_update_on_kvstore($update_on_kvstore);
    $self->_p->_updater(undef);

    if($kvstore)


