AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

}

# Initializes the parameters of the currently active module and then marks
# this bucketing module as initialized.
#
# Returns immediately when the parameters are already initialized, unless
# $force_init is true. Asserts that bind() has been called beforehand.
# Every named argument is handed through unchanged to the current module's
# own init_params.
method init_params(
    AI::MXNet::Initializer             :$initializer=AI::MXNet::Initializer->Uniform(scale => 0.01),
    Maybe[HashRef[AI::MXNet::NDArray]] :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] :$aux_params=,
    Bool                               :$allow_missing=0,
    Bool                               :$force_init=0,
    Bool                               :$allow_extra=0
)
{
    # Already initialized and no re-initialization requested: nothing to do.
    if($self->params_initialized and not $force_init)
    {
        return;
    }
    assert($self->binded, 'call bind before initializing the parameters');

    # Kept as an ordered list so the delegate sees the same argument sequence.
    my @init_args = (
        initializer   => $initializer,
        arg_params    => $arg_params,
        aux_params    => $aux_params,
        allow_missing => $allow_missing,
        force_init    => $force_init,
        allow_extra   => $allow_extra,
    );
    my $active_module = $self->_curr_module;
    $active_module->init_params(@init_args);

    # Clear the dirty flag and record that initialization has happened.
    $self->_params_dirty(0);
    $self->params_initialized(1);
}

# Returns whatever the currently active module's get_states returns.
#
# Asserts that this module is bound and that its parameters are initialized.
# $merge_multi_context (default 1) is passed straight through to the
# current module.
method get_states(Bool $merge_multi_context=1)
{
    assert($self->binded && $self->params_initialized);
    my $active_module = $self->_curr_module;
    return $active_module->get_states($merge_multi_context);
}

# Forwards the optional named arguments $states and $value to the currently
# active module's set_states and returns its result.
#
# Asserts that this module is bound and that its parameters are initialized.
method set_states(:$states=, :$value=)
{
    assert($self->binded && $self->params_initialized);
    my $active_module = $self->_curr_module;
    return $active_module->set_states(states => $states, value => $value);
}

=head2 bind

    Binding an AI::MXNet::Module::Bucketing means setting up the buckets and binding the
    executor for the default bucket key. Executors corresponding to other keys are
    bound afterwards with switch_bucket.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
        This should correspond to the symbol for the default bucket.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        This should correspond to the symbol for the default bucket.
    :$for_training : Bool
        Default is 1.
    :$inputs_need_grad : Bool
        Default is 0.
    :$force_rebind : Bool
        Default is 0.
    :$shared_module : Maybe[AI::MXNet::BaseModule]
        Default is undef. Sharing is not supported for bucketing modules;
        bind() asserts that this value is not defined.
    :$grad_req : str, array ref of str, hash ref of str to str
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (str) or for each argument (array ref, hash ref).
    :$bucket_key : str
        Bucket key to bind with. Defaults to ->default_bucket_key when not given.
=cut

method bind(
    ArrayRef[AI::MXNet::DataDesc|NameShape]                   :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]            :$label_shapes=,
    Bool                                                      :$for_training=1,
    Bool                                                      :$inputs_need_grad=0,
    Bool                                                      :$force_rebind=0,
    Maybe[AI::MXNet::BaseModule]                              :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]                            :$grad_req='write',
    Maybe[Str]                                                :$bucket_key=
)
{
    # in case we already initialized params, keep it
    my ($arg_params, $aux_params);
    if($self->params_initialized)
    {
        ($arg_params, $aux_params) = $self->get_params;
    }

    # force rebinding is typically used when one want to switch from
    # training to prediction phase.
    $self->_reset_bind if $force_rebind;

    if($self->binded)
    {
        $self->logger->warning('Already binded, ignoring bind()');
        return;
    }

    assert((not defined $shared_module), 'shared_module for BucketingModule is not supported');

    $self->for_training($for_training);
    $self->inputs_need_grad($inputs_need_grad);
    $self->binded(1);

    my ($symbol, $data_names, $label_names) = &{$self->_sym_gen}($bucket_key//$self->_default_bucket_key);
    my $module = AI::MXNet::Module->new(
            symbol            => $symbol,
            data_names        => $data_names,
            label_names       => $label_names,
            logger            => $self->logger,
            context           => $self->_context,
            work_load_list    => $self->_work_load_list,
            state_names       => $self->_state_names,
            fixed_param_names => $self->_fixed_param_names
    );
    $module->bind(
        data_shapes      => $data_shapes,
        label_shapes     => $label_shapes,
        for_training     => $for_training,
        inputs_need_grad => $inputs_need_grad,
        force_rebind     => 0,
        shared_module    => undef,
        grad_req         => $grad_req
    );
    $self->_curr_module($module);
    $self->_curr_bucket_key($self->_default_bucket_key);



( run in 0.398 second using v1.01-cache-2.11-cpan-39bf76dae61 )