AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
package AI::MXNet::BatchEndParam;
use Mouse;
use AI::MXNet::Function::Parameters;

# Parameter object describing the state of training at the end of a batch.
# The training loop in this file builds one per batch and hands it to every
# batch_end_callback.

# Zero-based epoch counter.
has 'epoch'       => (is => 'rw', isa => 'Int');
# Zero-based index of the batch within the current epoch.
has 'nbatch'      => (is => 'rw', isa => 'Int');
# Metric object accumulated over the batches seen so far in this epoch.
has 'eval_metric' => (is => 'rw', isa => 'AI::MXNet::EvalMetric');
package AI::MXNet::Module::Base;
use Mouse;
use AI::MXNet::Base;
use Time::HiRes qw(time);
=head1 NAME
AI::MXNet::Module::Base - Base class for AI::MXNet::Module and AI::MXNet::Module::Bucketing
=cut
# Normalize a value to an array reference.
# A plain (unblessed) ARRAY reference is returned unchanged; any other value,
# including undef, code refs and objects, is wrapped in a new
# one-element array ref.
func _as_list($obj)
{
    return $obj if ref($obj) eq 'ARRAY';
    return [$obj];
}
# Check that all input names are in the symbol's arguments
# Verify that each requested input name is an argument of $symbol.
#
# Parameters:
#   $symbol   - symbol whose list_arguments() is consulted
#   $names    - input names the module was created with
#   $typename - label used in the message ("data", "label", ...)
#   $throw    - true: confess on a missing name; false: log a warning
#
# 'softmax_label' is never reported as missing. The "did you mean" hint
# lists the symbol's arguments whose names do not end in
# _weight/_bias/_gamma/_beta.
method _check_input_names(
    AI::MXNet::Symbol $symbol,
    ArrayRef[Str]     $names,
    Str               $typename,
    Bool              $throw
)
{
    my @arguments   = @{ $symbol->list_arguments };
    my %is_argument = map { ($_ => 1) } @arguments;
    my @candidates  = grep { !/_(?:weight|bias|gamma|beta)$/ } @arguments;
    for my $name (@$names)
    {
        next if exists $is_argument{$name};
        next if $name eq 'softmax_label';
        my $msg = sprintf(
            "\033[91mYou created Module with Module(..., %s_names=%s) but "
            ."input with name '%s' is not found in symbol.list_arguments(). "
            ."Did you mean one of:\n\t%s\033[0m",
            $typename, "@$names", $name, join("\n\t", @candidates)
        );
        # confess never returns, so the warning runs only when $throw is false.
        confess($msg) if $throw;
        AI::MXNet::Logging->warning($msg);
    }
}
# Check that input names match the input data descriptors
# Verify that the names carried by $data_shapes equal $data_names,
# element by element and in order.
#
# Parameters:
#   $data_names  - expected names
#   $data_shapes - [name, shape] pairs or DataDesc objects; element 0 of
#                  each is taken as its name
#   $name        - label used in the message ("data" or "label")
#   $throw       - true: confess on mismatch; false: log a warning
#
# A lone 'softmax_label' name with no shapes at all is accepted silently.
method _check_names_match(
    ArrayRef[Str]                 $data_names,
    ArrayRef[NameShapeOrDataDesc] $data_shapes,
    Str                           $name,
    Bool                          $throw
)
{
    return if (not @$data_shapes and @$data_names == 1 and $data_names->[0] eq 'softmax_label');
    # DataDesc objects are array-dereferenced here just like raw pairs
    # (presumably via overloading on DataDesc).
    my @actual = map { $_->[0] } @{ $data_shapes };
    # Compare element-wise rather than joining the lists with spaces:
    # joining would consider ['a b'] and ['a', 'b'] to be the same.
    my $matches = (@actual == @$data_names)
        && !(grep { $actual[$_] ne $data_names->[$_] } 0 .. $#actual);
    return if $matches;
    my $msg = sprintf(
        "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)",
        $name, $name, "@$data_shapes", "@$data_names"
    );
    if($throw)
    {
        confess($msg);
    }
    else
    {
        AI::MXNet::Logging->warning($msg);
    }
}
# Convert data/label shape descriptions to AI::MXNet::DataDesc objects and
# validate them against the expected names.
#
# Parameters:
#   $data_names   - expected data input names
#   $label_names  - expected label names; undef is treated as an empty list
#   $data_shapes  - DataDesc objects or [name, shape, ...] array refs
#   $label_shapes - same as $data_shapes, or undef for no labels
#
# Returns ($data_shapes, $label_shapes). Entries that are already blessed are
# kept as-is; the others are passed to AI::MXNet::DataDesc->new.
# $label_shapes stays undef when it was undef.
# A data-name mismatch is fatal; a label-name mismatch only warns.
method _parse_data_desc(
    ArrayRef[Str]                        $data_names,
    Maybe[ArrayRef[Str]]                 $label_names,
    ArrayRef[NameShapeOrDataDesc]        $data_shapes,
    Maybe[ArrayRef[NameShapeOrDataDesc]] $label_shapes
)
{
    # $label_names may be undef per the signature, but _check_names_match
    # requires an ArrayRef[Str]; normalize so the label check cannot die on
    # the type constraint (or on dereferencing undef).
    $label_names //= [];
    my $to_descs = sub {
        return [map { blessed $_ ? $_ : AI::MXNet::DataDesc->new(@$_) } @{ $_[0] }];
    };
    $data_shapes = $to_descs->($data_shapes);
    $self->_check_names_match($data_names, $data_shapes, 'data', 1);
    if($label_shapes)
    {
        $label_shapes = $to_descs->($label_shapes);
        $self->_check_names_match($label_names, $label_shapes, 'label', 0);
    }
    else
    {
        $self->_check_names_match($label_names, [], 'label', 0);
    }
    return ($data_shapes, $label_shapes);
}
=head1 DESCRIPTION
The base class of modules. A module represents a computation component. A module is
designed to abstract a computation "machine" that one can run forward and backward on,
update the parameters of, etc. We aim to make the APIs easy to use, especially when
the imperative API is needed to work with multiple modules (e.g. a stochastic
depth network).
A module has several states:
- Initial state. Memory is not allocated yet, not ready for computation yet.
- Binded. Shapes for inputs, outputs, and parameters are all known, memory allocated,
ready for computation.
- Parameter initialized. For modules with parameters, doing computation before initializing
the parameters might result in undefined outputs.
- Optimizer installed. An optimizer can be installed to a module. After this, the parameters
of the module can be updated according to the optimizer after gradients are computed
(forward-backward).
In order for a module to interact with others, a module should be able to report the
following information in its raw stage (before being bound):
- data_names: array ref of strings indicating the names of the required data.
- output_names: array ref of strings indicating the names of the required outputs.
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
Maybe[EvalMetric] :$validation_metric=,
Maybe[AI::MXNet::Monitor] :$monitor=
)
{
$self->bind(
data_shapes => $train_data->provide_data,
label_shapes => $train_data->provide_label,
for_training => 1,
force_rebind => $force_rebind
);
if($monitor)
{
$self->install_monitor($monitor);
}
$self->init_params(
initializer => $initializer,
arg_params => $arg_params,
aux_params => $aux_params,
allow_missing => $allow_missing,
force_init => $force_init
);
$self->init_optimizer(
kvstore => $kvstore,
optimizer => $optimizer,
optimizer_params => $optimizer_params
);
if(not defined $validation_metric)
{
$validation_metric = $eval_metric;
}
$eval_metric = AI::MXNet::Metric->create($eval_metric)
unless blessed $eval_metric;
################################################################################
# training loop
################################################################################
for my $epoch ($begin_epoch..$num_epoch-1)
{
my $tic = time;
$eval_metric->reset;
my $nbatch = 0;
my $end_of_batch = 0;
my $next_data_batch = <$train_data>;
while(not $end_of_batch)
{
my $data_batch = $next_data_batch;
$monitor->tic if $monitor;
$self->forward_backward($data_batch);
$self->update;
$next_data_batch = <$train_data>;
if(defined $next_data_batch)
{
$self->prepare($next_data_batch);
}
else
{
$end_of_batch = 1;
}
$self->update_metric($eval_metric, $data_batch->label);
$monitor->toc_print if $monitor;
if(defined $batch_end_callback)
{
my $batch_end_params = AI::MXNet::BatchEndParam->new(
epoch => $epoch,
nbatch => $nbatch,
eval_metric => $eval_metric
);
for my $callback (@{ _as_list($batch_end_callback) })
{
&{$callback}($batch_end_params);
}
}
$nbatch++;
}
# one epoch of training is finished
my $name_value = $eval_metric->get_name_value;
while(my ($name, $val) = each %{ $name_value })
{
$self->logger->info('Epoch[%d] Train-%s=%f', $epoch, $name, $val);
}
my $toc = time;
$self->logger->info('Epoch[%d] Time cost=%.3f', $epoch, ($toc-$tic));
# sync aux params across devices
my ($arg_params, $aux_params) = $self->get_params;
$self->set_params($arg_params, $aux_params);
if($epoch_end_callback)
{
for my $callback (@{ _as_list($epoch_end_callback) })
{
&{$callback}($epoch, $self->get_symbol, $arg_params, $aux_params);
}
}
#----------------------------------------
# evaluation on validation set
if(defined $eval_data)
{
my $res = $self->score(
$eval_data,
$validation_metric,
score_end_callback => $eval_end_callback,
batch_end_callback => $eval_batch_end_callback,
epoch => $epoch
);
#TODO: pull this into default
while(my ($name, $val) = each %{ $res })
{
$self->logger->info('Epoch[%d] Validation-%s=%f', $epoch, $name, $val);
}
}
# end of 1 epoch, reset the data-iter for another epoch
$train_data->reset;
}
}
################################################################################
# Symbol information
################################################################################
( run in 0.820 second using v1.01-cache-2.11-cpan-39bf76dae61 )