AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Bucketing.pm view on Meta::CPAN
kvstore => $kv_store,
optimizer => $optimizer,
optimizer_params => {
learning_rate => $lr,
momentum => $mom,
wd => $wd,
},
initializer => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
num_epoch => $num_epoch,
batch_end_callback => mx->callback->Speedometer($batch_size, $disp_batches),
($chkp_epoch ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);
=head1 DESCRIPTION
Implements the AI::MXNet::Module::Base API and allows a different symbol
to be used for each mini-batch of data, selected by the C<bucket_key>
that the mini-batch provides.
=cut
=head2 new
Parameters
----------
$sym_gen : subref or any perl object that overloads &{} op
A sub that, when called with a bucket key, returns a list of three
elements: ($symbol, $data_names, $label_names).
$default_bucket_key : str or anything else
The key for the default bucket.
$logger : Logger
$context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
Default is cpu(0)
$work_load_list : array ref of Num
Default is undef, indicating uniform workload.
$fixed_param_names: arrayref of str
Default is undef, indicating no network parameters are fixed.
$state_names : arrayref of str
States are similar to data and label, but are not provided by the data iterator.
Instead, they are initialized to 0 and can be set by set_states().
=cut
# Bucketing module: keeps one AI::MXNet::Module per bucket key, each built
# from the symbol that sym_gen returns for that key.
extends 'AI::MXNet::Module::Base';
# Callable (sub ref or object overloading &{}) invoked with a bucket key;
# returns ($symbol, $data_names, $label_names).
has '_sym_gen' => (is => 'ro', init_arg => 'sym_gen', required => 1);
# Key of the default bucket. Before binding, data_names/output_names are
# derived from sym_gen's output for this key; after binding, this bucket's
# module is passed as shared_module when new buckets are bound.
has '_default_bucket_key' => (is => 'rw', init_arg => 'default_bucket_key', required => 1);
# Device(s) handed to every per-bucket module; defaults to AI::MXNet::Context->cpu.
has '_context' => (
is => 'ro', isa => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
lazy => 1, default => sub { AI::MXNet::Context->cpu },
init_arg => 'context'
);
# Optional workload weights, forwarded unchanged to every per-bucket module.
has '_work_load_list' => (is => 'rw', init_arg => 'work_load_list', isa => 'ArrayRef[Num]');
# Module and key of the currently active bucket (set by switch_bucket,
# cleared by _reset_bind); not settable from the constructor.
has '_curr_module' => (is => 'rw', init_arg => undef);
has '_curr_bucket_key' => (is => 'rw', init_arg => undef);
# Cache of bucket_key => AI::MXNet::Module; emptied by _reset_bind.
has '_buckets' => (is => 'rw', init_arg => undef, default => sub { +{} });
# Names of network parameters to keep fixed; BUILD sets [] when not supplied.
has '_fixed_param_names' => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'fixed_param_names');
# Names of state inputs (see POD); BUILD sets [] when not supplied.
has '_state_names' => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'state_names');
# Set to 0 in BUILD and to 1 by update(). Presumably marks device-side
# parameters as changed since the last sync; confirm against get_params.
has '_params_dirty' => (is => 'rw', init_arg => undef);
sub BUILD
{
    my ($self, $args) = @_;

    # Optional name lists fall back to empty array refs.
    $self->_fixed_param_names([]) if not defined $args->{fixed_param_names};
    $self->_state_names([])       if not defined $args->{state_names};
    $self->_params_dirty(0);

    # Validate every configured name against the default bucket's symbol.
    my ($symbol, $data_names, $label_names)
        = $self->_sym_gen->($self->_default_bucket_key);
    my @checks = (
        [ $data_names  // [],        "data",        1 ],
        [ $label_names // [],        "label",       0 ],
        [ $self->_state_names,       "state",       1 ],
        [ $self->_fixed_param_names, "fixed_param", 1 ],
    );
    $self->_check_input_names($symbol, @$_) for @checks;
}
method _reset_bind()
{
    # Forget all bound state: nothing is bound, the bucket_key => module
    # cache is emptied, and no bucket is active.
    $self->binded(0);
    $self->_buckets(+{});
    $self->_curr_module(undef);
    $self->_curr_bucket_key(undef);
}
method data_names()
{
    # Once bound, the active bucket's module is authoritative.
    return $self->_curr_module->data_names if $self->binded;

    # Otherwise take element 1 of sym_gen's ($symbol, $data_names, $label_names)
    # for the default bucket.
    return ( $self->_sym_gen->($self->_default_bucket_key) )[1];
}
method output_names()
{
    # Once bound, delegate to the active bucket's module.
    # (Previously called the misspelled, nonexistent method 'ouput_names'.)
    return $self->_curr_module->output_names if $self->binded;

    # Before binding, derive the names from the default bucket's symbol.
    # (Previously called the misspelled, nonexistent method 'list_ouputs'.)
    my ($symbol) = $self->_sym_gen->($self->_default_bucket_key);
    return $symbol->list_outputs;
}
method data_shapes()
{
    # Shapes only exist once a module has been bound.
    assert($self->binded);
    my $active = $self->_curr_module;
    return $active->data_shapes;
}
method label_shapes()
{
    # Shapes only exist once a module has been bound.
    assert($self->binded);
    my $active = $self->_curr_module;
    return $active->label_shapes;
}
method output_shapes()
{
    # Shapes only exist once a module has been bound.
    assert($self->binded);
    my $active = $self->_curr_module;
    return $active->output_shapes;
}
method get_params()
lib/AI/MXNet/Module/Bucketing.pm view on Meta::CPAN
assert($self->binded, 'call bind before switching bucket');
if(not exists $self->_buckets->{ $bucket_key })
{
my ($symbol, $data_names, $label_names) = &{$self->_sym_gen}($bucket_key);
my $module = AI::MXNet::Module->new(
symbol => $symbol,
data_names => $data_names,
label_names => $label_names,
logger => $self->logger,
context => $self->_context,
work_load_list => $self->_work_load_list
);
$module->bind(
data_shapes => $data_shapes,
label_shapes => $label_shapes,
for_training => $self->_curr_module->for_training,
inputs_need_grad => $self->_curr_module->inputs_need_grad,
force_rebind => 0,
shared_module => $self->_buckets->{ $self->_default_bucket_key },
);
$self->_buckets->{ $bucket_key } = $module;
}
$self->_curr_module($self->_buckets->{ $bucket_key });
$self->_curr_bucket_key($bucket_key);
}
method init_optimizer(
Str :$kvstore='local',
Optimizer :$optimizer='sgd',
HashRef :$optimizer_params={ learning_rate => 0.01 },
Bool :$force_init=0
)
{
    assert($self->binded and $self->params_initialized);
    if($self->optimizer_initialized and not $force_init)
    {
        $self->logger->warning('optimizer already initialized, ignoring.');
        return;
    }

    # Only the active module creates an optimizer ...
    my $owner = $self->_curr_module;
    $owner->init_optimizer(
        kvstore          => $kvstore,
        optimizer        => $optimizer,
        optimizer_params => $optimizer_params,
        force_init       => $force_init
    );

    # ... every other bucket's module borrows it.
    $_->borrow_optimizer($owner)
        for grep { $_ ne $owner } values %{ $self->_buckets };

    $self->optimizer_initialized(1);
}
method prepare(AI::MXNet::DataBatch $data_batch)
{
    assert($self->binded and $self->params_initialized);

    # Remember the active bucket so it can be restored afterwards.
    my $saved_key = $self->_curr_bucket_key;

    # Visiting the batch's bucket binds a module for it if none exists yet.
    $self->switch_bucket(
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );

    # Restore the previously active bucket.
    $self->switch_bucket(bucket_key => $saved_key);
}
method forward(
AI::MXNet::DataBatch $data_batch,
Bool :$is_train=
)
{
    assert($self->binded and $self->params_initialized);

    # Activate the batch's bucket, binding a module for it if necessary.
    $self->switch_bucket(
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );

    my $active = $self->_curr_module;
    $active->forward($data_batch, is_train => $is_train);
}
method backward(Maybe[ArrayRef[AI::MXNet::NDArray]|AI::MXNet::NDArray] $out_grads=)
{
    assert($self->binded and $self->params_initialized);
    # Gradients flow through whichever bucket ran the last forward pass.
    my $active = $self->_curr_module;
    $active->backward($out_grads);
}
method update()
{
    assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
    # Flag that parameters are about to change, then let the active module step.
    $self->_params_dirty(1);
    my $active = $self->_curr_module;
    $active->update;
}
method get_outputs(Bool $merge_multi_context=1)
{
    assert($self->binded and $self->params_initialized);
    # Outputs come from the module of the currently active bucket.
    my $active = $self->_curr_module;
    return $active->get_outputs($merge_multi_context);
}
method get_input_grads(Bool $merge_multi_context=1)
{
    # Input gradients are only available when inputs_need_grad is set.
    assert($self->binded and $self->params_initialized and $self->inputs_need_grad);
    my $active = $self->_curr_module;
    return $active->get_input_grads($merge_multi_context);
}
method update_metric(
AI::MXNet::EvalMetric $eval_metric,
ArrayRef[AI::MXNet::NDArray] $labels
)
{
    assert($self->binded and $self->params_initialized);
    # Score the active bucket's latest outputs against $labels.
    my $active = $self->_curr_module;
    $active->update_metric($eval_metric, $labels);
}
method symbol()
{
    # The symbol of the currently active bucket (requires a bound module).
    assert($self->binded);
    my $active = $self->_curr_module;
    return $active->symbol;
}
method get_symbol()
{
assert($self->binded);
return $self->_buckets->{ $self->_default_bucket_key }->symbol;
( run in 0.849 second using v1.01-cache-2.11-cpan-39bf76dae61 )