AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
- get_outputs(): get outputs of the previous forward operation.
- get_input_grads(): get the gradients with respect to the inputs computed
in the previous backward operation.
- update_metric(metric, labels): update performance metric for the previous forward
computed results.
- other properties (mostly for backward compatibility)
- symbol: the underlying symbolic graph for this module (if any)
This property is not necessarily constant. For example, for AI::MXNet::Module::Bucketing,
this property is simply the *current* symbol being used. For other modules,
this value might not be well defined.
When those intermediate-level APIs are implemented properly, the following
high-level APIs will automatically be available for a module:
- fit: train the module parameters on a data set
- predict: run prediction on a data set and collect outputs
- score: run prediction on a data set and evaluate performance
=cut
# Logger used by the module; defaults to AI::MXNet::Logging->get_logger.
has 'logger' => (
    is      => 'rw',
    default => sub { AI::MXNet::Logging->get_logger },
);
# The underlying symbol, passed to the constructor under the key `symbol`.
has '_symbol' => (
    is       => 'rw',
    isa      => 'AI::MXNet::Symbol',
    init_arg => 'symbol',
);
# Module state flags: read/write Bools that start false and cannot be
# supplied through the constructor (init_arg => undef).
for my $flag (qw/binded for_training inputs_need_grad params_initialized optimizer_initialized/)
{
    has $flag => (is => 'rw', isa => 'Bool', init_arg => undef, default => 0);
}
################################################################################
# High Level API
################################################################################
=head2 forward_backward

Convenience wrapper: runs a training-mode forward pass over $data_batch
(is_train => 1) and then immediately calls backward().

=cut

method forward_backward(AI::MXNet::DataBatch $data_batch)
{
    my %train_mode = (is_train => 1);
    $self->forward($data_batch, %train_mode);
    $self->backward;
}
=head2 score

Run prediction on $eval_data and evaluate the performance according to
$eval_metric.

Parameters
----------
$eval_data : AI::MXNet::DataIter
$eval_metric : AI::MXNet::EvalMetric
    A metric object. Any value that is not an AI::MXNet::EvalMetric
    instance is handed to AI::MXNet::Metric->create (presumably a metric
    name).
:$num_batch= : Maybe[Int]
    Number of batches to run. Default is undef, indicating run until the
    AI::MXNet::DataIter finishes.
:$batch_end_callback= : Maybe[Callback]
    Called after every batch with an AI::MXNet::BatchEndParam whose nbatch
    is the 0-based index of that batch. Could also be an array ref of
    functions.
:$score_end_callback= : Maybe[Callback]
    Called once after evaluation with an AI::MXNet::BatchEndParam whose
    nbatch is the number of batches evaluated. Could also be an array ref
    of functions.
:$reset=1 : Bool
    Default 1, indicating whether we should reset $eval_data before starting
    evaluating.
:$epoch=0 : Int
    Default is 0. For compatibility, this will be passed to callbacks (if any).
    During training, this will correspond to the training epoch number.

Returns
-------
The value of $eval_metric->get_name_value.

=cut

method score(
    AI::MXNet::DataIter $eval_data,
    EvalMetric $eval_metric,
    Maybe[Int] :$num_batch=,
    Maybe[Callback]|ArrayRef[Callback] :$batch_end_callback=,
    Maybe[Callback]|ArrayRef[Callback] :$score_end_callback=,
    Bool :$reset=1,
    Int :$epoch=0
)
{
    assert($self->binded and $self->params_initialized);
    $eval_data->reset if $reset;
    # Anything that is not already a metric object goes through the factory.
    if(not blessed $eval_metric or not $eval_metric->isa('AI::MXNet::EvalMetric'))
    {
        $eval_metric = AI::MXNet::Metric->create($eval_metric);
    }
    $eval_metric->reset();
    # $nbatch is the 0-based index of the batch being processed (reported to
    # batch_end callbacks) and, after the loop, the number of batches scored.
    my $nbatch = 0;
    # Test the batch limit *before* pulling from the iterator, so no batch is
    # read and then thrown away (observable when $reset is false and the
    # iterator is reused by a later call).
    while(not defined $num_batch or $nbatch != $num_batch)
    {
        my $eval_batch = <$eval_data>;
        # The iterator signals exhaustion with undef.
        last unless defined $eval_batch;
        $self->forward($eval_batch, is_train => 0);
        $self->update_metric($eval_metric, $eval_batch->label);
        if(defined $batch_end_callback)
        {
            my $batch_end_params = AI::MXNet::BatchEndParam->new(
                epoch       => $epoch,
                nbatch      => $nbatch,
                eval_metric => $eval_metric
            );
            for my $callback (@{ _as_list($batch_end_callback) })
            {
                $callback->($batch_end_params);
            }
        }
        $nbatch++;
    }
    if(defined $score_end_callback)
    {
        my $params = AI::MXNet::BatchEndParam->new(
            epoch       => $epoch,
            nbatch      => $nbatch,
            eval_metric => $eval_metric,
        );
        for my $callback (@{ _as_list($score_end_callback) })
        {
            # Invoke through the code ref held in $callback; a bare
            # `&{callback}` would instead name a package sub `callback`.
            $callback->($params);
        }
    }
    return $eval_metric->get_name_value;
}
=head2 iter_predict
( run in 2.391 seconds using v1.01-cache-2.11-cpan-63c85eba8c4 )