AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Base.pm  view on Meta::CPAN

    $merge_multi_context=1 : Bool
=cut

method get_input_grads(Bool $merge_multi_context=1)
{
    # Abstract stub: concrete module classes are expected to override this.
    confess("NotImplemented");
}

=head2 update

    Update parameters according to the installed optimizer and the gradients computed
    in the previous forward-backward batch.
=cut

method update()
{
    # Abstract stub: concrete module classes are expected to override this.
    confess("NotImplemented");
}

=head2 update_metric

    Evaluate and accumulate evaluation metric on outputs of the last forward computation.

    Parameters
    ----------
    $eval_metric : EvalMetric
    $labels : ArrayRef[AI::MXNet::NDArray]
        Typically $data_batch->label.
=cut

method update_metric(EvalMetric $eval_metric, ArrayRef[AI::MXNet::NDArray] $labels)
{
    # Abstract stub: always dies; subclasses provide the real implementation.
    confess("NotImplemented");
}

################################################################################
# module setup
################################################################################

=head2 bind

    Binds the symbols in order to construct the executors. This is necessary
    before the computations can be performed.

    Parameters
    ----------
    $data_shapes : ArrayRef[AI::MXNet::DataDesc]
        Typically $data_iter->provide_data.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc]]
        Typically $data_iter->provide_label.
    :$for_training=1 : Bool
        Default is 1. Whether the executors should be bound for training.
    :$inputs_need_grad=0 : Bool
        Default is 0. Whether the gradients with respect to the input data need to be computed.
        Typically this is not needed, but it may be required when composing modules.
    :$force_rebind=0 : Bool
        Default is 0. This function does nothing if the executors are already
        bound. Set this to 1 to force the executors to be rebound.
    :$shared_module= : A subclass of AI::MXNet::Module::Base
        Default is undef. This is used in bucketing. When not undef, the shared module
        essentially corresponds to a different bucket -- a module with different symbol
        but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
    :$grad_req='write' : Str|ArrayRef[Str]|HashRef[Str]
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (str) or for each argument (array ref, hash ref).
=cut

# Abstract bind(). The POD documents $shared_module as "a subclass of
# AI::MXNet::Module::Base", so the type constraint names that base class
# (it previously named AI::MXNet::BaseModule, which does not match the
# documented contract). undef remains accepted via Maybe[].
method bind(
    ArrayRef[AI::MXNet::DataDesc]         $data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc]] :$label_shapes=,
    Bool                                 :$for_training=1,
    Bool                                 :$inputs_need_grad=0,
    Bool                                 :$force_rebind=0,
    Maybe[AI::MXNet::Module::Base]       :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]       :$grad_req='write'
)
{
    # Abstract stub: concrete module classes are expected to override this.
    confess("NotImplemented")
}

=head2 init_optimizer

    Install and initialize optimizers.

    Parameters
    ----------
    :$kvstore='local' : str or KVStore
    :$optimizer='sgd' : str or Optimizer
    :$optimizer_params={ learning_rate => 0.01 } : hash ref
    :$force_init=0 : Bool
=cut

# Abstract init_optimizer(). The POD documents $kvstore as "str or KVStore",
# so the constraint accepts either a store-type string (default 'local') or
# an AI::MXNet::KVStore object; previously only Str was allowed.
method init_optimizer(
    Str|AI::MXNet::KVStore :$kvstore='local',
    Optimizer              :$optimizer='sgd',
    HashRef                :$optimizer_params={ learning_rate => 0.01 },
    Bool                   :$force_init=0
)
{
    # Abstract stub: concrete module classes are expected to override this.
    confess("NotImplemented")
}

################################################################################
# misc
################################################################################

=head2 symbol

    The symbol associated with this module.

    For module types other than AI::MXNet::Module (e.g. AI::MXNet::Module::Bucketing),
    this property might not be constant throughout the module's lifetime. Some modules
    might not even be associated with any symbol.
=cut

method symbol()
{
    # Accessor: yields whatever $self->_symbol holds (implicit return of
    # the last expression, same calling context as an explicit return).
    $self->_symbol
}

# A Perl module file must end by evaluating to a true value for require/use.
1;



( run in 0.674 second using v1.01-cache-2.11-cpan-39bf76dae61 )