AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Bucketing.pm view on Meta::CPAN
Str|ArrayRef[Str]|HashRef[Str] :$grad_req='write',
Maybe[Str] :$bucket_key=
)
{
# in case we already initialized params, keep it
my ($arg_params, $aux_params);
if($self->params_initialized)
{
($arg_params, $aux_params) = $self->get_params;
}
# force rebinding is typically used when one want to switch from
# training to prediction phase.
$self->_reset_bind if $force_rebind;
if($self->binded)
{
$self->logger->warning('Already binded, ignoring bind()');
return;
}
assert((not defined $shared_module), 'shared_module for BucketingModule is not supported');
$self->for_training($for_training);
$self->inputs_need_grad($inputs_need_grad);
$self->binded(1);
my ($symbol, $data_names, $label_names) = &{$self->_sym_gen}($bucket_key//$self->_default_bucket_key);
my $module = AI::MXNet::Module->new(
symbol => $symbol,
data_names => $data_names,
label_names => $label_names,
logger => $self->logger,
context => $self->_context,
work_load_list => $self->_work_load_list,
state_names => $self->_state_names,
fixed_param_names => $self->_fixed_param_names
);
$module->bind(
data_shapes => $data_shapes,
label_shapes => $label_shapes,
for_training => $for_training,
inputs_need_grad => $inputs_need_grad,
force_rebind => 0,
shared_module => undef,
grad_req => $grad_req
);
$self->_curr_module($module);
$self->_curr_bucket_key($self->_default_bucket_key);
$self->_buckets->{ $self->_default_bucket_key } = $module;
# copy back saved params, if already initialized
if($self->params_initialized)
{
$self->set_params($arg_params, $aux_params);
}
}
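=head2 switch_bucket

    Switches to a different bucket. This changes $self->_curr_module.
    Must be called after bind(). If no module exists yet for the requested
    bucket, one is created from the symbol generator, bound with the default
    bucket's module passed as its shared module, and cached for later switches.

    Parameters
    ----------
    :$bucket_key : str (or any Perl object that overloads the "" operator)
        The key of the target bucket.
    :$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_label.

=cut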
method switch_bucket(
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$data_shapes=,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=,
    :$bucket_key
)
{
    # Make $bucket_key the active bucket, creating and binding its module on
    # first use. The default bucket's module is created by bind(), which must
    # therefore have been called first.
    assert($self->binded, 'call bind before switching bucket');
    if(not exists $self->_buckets->{ $bucket_key })
    {
        my ($symbol, $data_names, $label_names) = $self->_sym_gen->($bucket_key);
        # Build the per-bucket module with the same configuration bind() uses
        # for the default bucket. fixed_param_names and state_names must be
        # forwarded here too; otherwise non-default buckets silently lose the
        # frozen-parameter list and state names set on the BucketingModule.
        my $module = AI::MXNet::Module->new(
            symbol            => $symbol,
            data_names        => $data_names,
            label_names       => $label_names,
            logger            => $self->logger,
            context           => $self->_context,
            work_load_list    => $self->_work_load_list,
            state_names       => $self->_state_names,
            fixed_param_names => $self->_fixed_param_names
        );
        # Inherit the training flags of the currently active module, and pass
        # the default bucket's module as shared_module so the new module can
        # share with it.
        $module->bind(
            data_shapes      => $data_shapes,
            label_shapes     => $label_shapes,
            for_training     => $self->_curr_module->for_training,
            inputs_need_grad => $self->_curr_module->inputs_need_grad,
            force_rebind     => 0,
            shared_module    => $self->_buckets->{ $self->_default_bucket_key },
        );
        # Cache the module so later switches to this key reuse it.
        $self->_buckets->{ $bucket_key } = $module;
    }
    $self->_curr_module($self->_buckets->{ $bucket_key });
    $self->_curr_bucket_key($bucket_key);
}
method init_optimizer(
Str :$kvstore='local',
Optimizer :$optimizer='sgd',
HashRef :$optimizer_params={ learning_rate => 0.01 },
Bool :$force_init=0
)
{
assert($self->binded and $self->params_initialized);
if($self->optimizer_initialized and not $force_init)
{
$self->logger->warning('optimizer already initialized, ignoring.');
return;
}
$self->_curr_module->init_optimizer(
kvstore => $kvstore,
optimizer => $optimizer,
( run in 0.435 second using v1.01-cache-2.11-cpan-39bf76dae61 )