AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

    Str|ArrayRef[Str]|HashRef[Str]                            :$grad_req='write',
    Maybe[Str]                                                :$bucket_key=
)
{
    # in case we already initialized params, keep it
    my ($arg_params, $aux_params);
    if($self->params_initialized)
    {
        ($arg_params, $aux_params) = $self->get_params;
    }

    # force rebinding is typically used when one want to switch from
    # training to prediction phase.
    $self->_reset_bind if $force_rebind;

    if($self->binded)
    {
        $self->logger->warning('Already binded, ignoring bind()');
        return;
    }

    assert((not defined $shared_module), 'shared_module for BucketingModule is not supported');

    $self->for_training($for_training);
    $self->inputs_need_grad($inputs_need_grad);
    $self->binded(1);

    my ($symbol, $data_names, $label_names) = &{$self->_sym_gen}($bucket_key//$self->_default_bucket_key);
    my $module = AI::MXNet::Module->new(
            symbol            => $symbol,
            data_names        => $data_names,
            label_names       => $label_names,
            logger            => $self->logger,
            context           => $self->_context,
            work_load_list    => $self->_work_load_list,
            state_names       => $self->_state_names,
            fixed_param_names => $self->_fixed_param_names
    );
    $module->bind(
        data_shapes      => $data_shapes,
        label_shapes     => $label_shapes,
        for_training     => $for_training,
        inputs_need_grad => $inputs_need_grad,
        force_rebind     => 0,
        shared_module    => undef,
        grad_req         => $grad_req
    );
    $self->_curr_module($module);
    $self->_curr_bucket_key($self->_default_bucket_key);
    $self->_buckets->{ $self->_default_bucket_key } = $module;

    # copy back saved params, if already initialized
    if($self->params_initialized)
    {
        $self->set_params($arg_params, $aux_params);
    }
}

=head2 switch_bucket

    Switch to a different bucket. This will change $self->_curr_module.
    If no module exists yet for $bucket_key, a new one is created and bound
    before the switch. bind() must have been called beforehand.

    Parameters
    ----------
    :$bucket_key : str (or any perl object that overloads "" op)
        The key of the target bucket.
    :$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_label.

=cut

method switch_bucket(
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]            :$data_shapes=,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]            :$label_shapes=,
                                                              :$bucket_key
)
{
    # Makes the module registered under $bucket_key the current module,
    # creating and binding it first when the key has not been seen before.
    # Asserts that bind() has already been called on this object.
    assert($self->binded, 'call bind before switching bucket');

    # Per-bucket modules are built lazily and cached in $self->_buckets.
    if(not exists $self->_buckets->{ $bucket_key })
    {
        my ($symbol, $data_names, $label_names) = $self->_sym_gen->($bucket_key);

        # Construct the module with the same state_names/fixed_param_names
        # that bind() hands to the default-bucket module, so that every
        # bucket is configured consistently.
        my $module = AI::MXNet::Module->new(
            symbol            => $symbol,
            data_names        => $data_names,
            label_names       => $label_names,
            logger            => $self->logger,
            context           => $self->_context,
            work_load_list    => $self->_work_load_list,
            state_names       => $self->_state_names,
            fixed_param_names => $self->_fixed_param_names
        );

        # The training flags are taken from the module that is current at
        # the time of the switch; the default bucket's module is passed
        # as shared_module.
        $module->bind(
            data_shapes      => $data_shapes,
            label_shapes     => $label_shapes,
            for_training     => $self->_curr_module->for_training,
            inputs_need_grad => $self->_curr_module->inputs_need_grad,
            force_rebind     => 0,
            shared_module    => $self->_buckets->{ $self->_default_bucket_key },
        );
        $self->_buckets->{ $bucket_key } = $module;
    }

    $self->_curr_module($self->_buckets->{ $bucket_key });
    $self->_curr_bucket_key($bucket_key);
}

method init_optimizer(
    Str        :$kvstore='local',
    Optimizer  :$optimizer='sgd',
    HashRef    :$optimizer_params={ learning_rate => 0.01 },
    Bool       :$force_init=0
)
{
    assert($self->binded and $self->params_initialized);
    if($self->optimizer_initialized and not $force_init)
    {
        $self->logger->warning('optimizer already initialized, ignoring.');
        return;
    }

    $self->_curr_module->init_optimizer(
        kvstore           => $kvstore,
        optimizer         => $optimizer,



( run in 0.435 second using v1.01-cache-2.11-cpan-39bf76dae61 )