AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
$merge_multi_context=1 : Bool
=cut
method get_input_grads(Bool $merge_multi_context=1)
{
    # The base class provides no implementation: calling this always
    # raises "NotImplemented" via confess.
    confess("NotImplemented");
}
=head2 update

    Update the parameters using the installed optimizer and the gradients
    computed during the previous forward-backward batch.

=cut

method update()
{
    # Base-class placeholder: always raises "NotImplemented".
    confess("NotImplemented");
}
=head2 update_metric

    Evaluate the given metric on the outputs of the last forward computation
    and accumulate the result into it.

    Parameters
    ----------
    $eval_metric : EvalMetric
    $labels : ArrayRef[AI::MXNet::NDArray]
        Typically $data_batch->label.

=cut

method update_metric(
    EvalMetric                    $eval_metric,
    ArrayRef[AI::MXNet::NDArray]  $labels
)
{
    # Base-class placeholder: always raises "NotImplemented".
    confess("NotImplemented");
}
################################################################################
# module setup
################################################################################
=head2 bind

    Binds the symbols in order to construct the executors. This is necessary
    before any computation can be performed.

    Parameters
    ----------
    $data_shapes : ArrayRef[AI::MXNet::DataDesc]
        Typically is $data_iter->provide_data.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc]]
        Typically is $data_iter->provide_label.
    :$for_training=1 : Bool
        Default is 1. Whether the executors should be bound for training.
    :$inputs_need_grad=0 : Bool
        Default is 0. Whether the gradients with respect to the input data need
        to be computed. Typically this is not needed, but it might be needed when
        implementing composition of modules.
    :$force_rebind=0 : Bool
        Default is 0. This function does nothing if the executors are already
        bound. With this set to 1, the executors will be forced to rebind.
    :$shared_module= : Maybe[AI::MXNet::Module::Base]
        Default is undef. A subclass of AI::MXNet::Module::Base. This is used in
        bucketing. When not undef, the shared module essentially corresponds to a
        different bucket -- a module with a different symbol but with the same set
        of parameters (e.g. unrolled RNNs with different lengths).
    :$grad_req='write' : Str|ArrayRef[Str]|HashRef[Str]
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (string) or for each argument (array ref, hash ref).

=cut

method bind(
    ArrayRef[AI::MXNet::DataDesc]         $data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc]]  :$label_shapes=,
    Bool                                  :$for_training=1,
    Bool                                  :$inputs_need_grad=0,
    Bool                                  :$force_rebind=0,
    Maybe[AI::MXNet::Module::Base]        :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]        :$grad_req='write'
)
{
    # Base-class placeholder: always raises "NotImplemented".
    # The $shared_module constraint names AI::MXNet::Module::Base so that it
    # matches the documented contract ("a subclass of AI::MXNet::Module::Base");
    # it previously named AI::MXNet::BaseModule.
    confess("NotImplemented");
}
=head2 init_optimizer

    Install and initialize optimizers.

    Parameters
    ----------
    :$kvstore='local' : Str or AI::MXNet::KVStore
    :$optimizer='sgd' : Str or Optimizer
    :$optimizer_params={ learning_rate => 0.01 } : HashRef
    :$force_init=0 : Bool
        Default is 0.

=cut

method init_optimizer(
    Str|AI::MXNet::KVStore  :$kvstore='local',
    Optimizer               :$optimizer='sgd',
    HashRef                 :$optimizer_params={ learning_rate => 0.01 },
    Bool                    :$force_init=0
)
{
    # Base-class placeholder: always raises "NotImplemented".
    # $kvstore accepts a KVStore object as well as a name string, matching the
    # documented "Str or AI::MXNet::KVStore"; previously only Str was accepted.
    confess("NotImplemented");
}
################################################################################
# misc
################################################################################
=head2 symbol

    The symbol associated with this module.

    Except for AI::MXNet::Module, for other kinds of modules
    (e.g. AI::MXNet::Module::Bucketing) this property might not stay constant
    throughout the module's lifetime. Some modules might not even be associated
    with any symbol.

=cut

method symbol()
{
    # Delegates to the private _symbol method; as the last evaluated
    # expression, its value is returned in the caller's context.
    $self->_symbol;
}
# A Perl module file must evaluate to a true value so that require/use succeeds.
1;
( run in 0.674 second using v1.01-cache-2.11-cpan-39bf76dae61 )