AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module.pm view on Meta::CPAN
## TODO
## this class is here because of https://github.com/gfx/p5-Mouse/pull/67
## once 2.4.7 version of Mouse in Ubuntu for affected Perl version
## these accessors should be merged into main class
package AI::MXNet::Module::Private;
use Mouse;
# Holder for AI::MXNet::Module's internal state.  Every attribute below is a
# read/write accessor declared with init_arg => undef, so none of them can be
# supplied through new(); they are only set by calling the accessor.
my @module_private_attrs = qw(
    _param_names _fixed_param_names
    _aux_names _data_names _label_names _state_names
    _output_names _arg_params _aux_params
    _params_dirty _optimizer _kvstore
    _update_on_kvstore _updater _work_load_list
    _preload_opt_states _exec_group
    _data_shapes _label_shapes _context _grad_req
);
has \@module_private_attrs => (is => 'rw', init_arg => undef);
package AI::MXNet::Module;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use List::Util qw(max);
use Data::Dumper ();
use Mouse;
# Decide which kvstore (if any) to use and whether updates run on it.
#
# Parameters:
#   $kvstore    - undef, an already-built (blessed) kvstore object, or a
#                 kvstore type string handed to AI::MXNet::KVStore->create.
#   $num_device - number of devices in use.
#   $arg_params - name => NDArray map; only consulted for the 'local' type,
#                 to find the largest parameter.
#
# Returns the two-element list ($kv, $update_on_kvstore):
#   $kv                - the kvstore object, or undef when none is used
#                        (nothing requested, or one device with a type
#                        string that does not contain "dist").
#   $update_on_kvstore - 1 or 0.  Forced to 0 whenever $kv is false; also 0
#                        for 'local' when the largest product() of a
#                        parameter's shape exceeds 1024 * 1024 * 16.
func _create_kvstore(
    Maybe[Str|AI::MXNet::KVStore] $kvstore,
    Int $num_device,
    HashRef[AI::MXNet::NDArray] $arg_params
)
{
    my $update_on_kvstore = 1;
    my $kv;
    if(defined $kvstore)
    {
        if(blessed $kvstore)
        {
            # caller passed a ready-made kvstore object
            $kv = $kvstore;
        }
        elsif($num_device == 1 and $kvstore !~ /dist/)
        {
            # no need to use kv for single device and single machine
        }
        else
        {
            # create kvstore using the string type
            $kv = AI::MXNet::KVStore->create($kvstore);
            if($kvstore eq 'local')
            {
                # automatically select a proper local.
                # max() of an empty list is undef; default to 0 so an empty
                # $arg_params keeps update_on_kvstore on without emitting an
                # "uninitialized value" warning in the comparison below.
                my $max_size = max(
                    map { product(@{ $_->shape }) } values %{ $arg_params }
                ) // 0;
                if($max_size > 1024 * 1024 * 16)
                {
                    $update_on_kvstore = 0;
                }
            }
        }
    }
    $update_on_kvstore = 0 if not $kv;
    return ($kv, $update_on_kvstore);
}
func _initialize_kvstore(
AI::MXNet::KVStore :$kvstore,
HashRef[AI::MXNet::NDArray] :$arg_params,
ArrayRef[Str] :$param_names,
Bool :$update_on_kvstore,
ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] :$param_arrays
)
{
enumerate(sub{
my ($idx, $param_on_devs) = @_;
my $name = $param_names->[$idx];
$kvstore->init($name, $arg_params->{ $name });
if($update_on_kvstore)
{
$kvstore->pull($name, out => $param_on_devs, priority => -$idx);
}
}, $param_arrays);
lib/AI/MXNet/Module.pm view on Meta::CPAN
$self->_p->_update_on_kvstore($update_on_kvstore);
$self->_p->_updater(undef);
if($kvstore)
{
# copy initialized local parameters to kvstore
_initialize_kvstore(
kvstore => $kvstore,
param_arrays => $self->_p->_exec_group->_p->param_arrays,
arg_params => $self->_p->_arg_params,
param_names => $self->_p->_param_names,
update_on_kvstore => $update_on_kvstore
);
}
if($update_on_kvstore)
{
$kvstore->set_optimizer($self->_p->_optimizer);
}
else
{
$self->_p->_updater(AI::MXNet::Optimizer->get_updater($optimizer));
}
$self->optimizer_initialized(1);
if($self->_p->_preload_opt_states)
{
$self->load_optimizer_states($self->_p->_preload_opt_states);
$self->_p->_preload_opt_states(undef);
}
}
=head2 borrow_optimizer
Borrows the optimizer from a shared module. Used in bucketing, where exactly the same
optimizer (especially the same kvstore) has to be used.
Parameters
----------
shared_module : AI::MXNet::Module
=cut
method borrow_optimizer(AI::MXNet::Module $shared_module)
{
    assert($shared_module->optimizer_initialized);
    # Copy the optimizer-related private state, in this order, from the
    # shared module onto this one, then mark this module as initialized.
    my $mine   = $self->_p;
    my $theirs = $shared_module->_p;
    for my $accessor (qw/_optimizer _kvstore _update_on_kvstore _updater/)
    {
        $mine->$accessor($theirs->$accessor);
    }
    $self->optimizer_initialized(1);
}
# Run a forward pass on $data_batch.
#
# If the data shapes in the batch differ from the shapes the module is
# currently bound with, the module is reshaped first:
#   - data shapes come from $data_batch->provide_data when it has entries,
#     otherwise they are rebuilt from the current data descriptors (name,
#     dtype, layout) combined with the batch's new shapes;
#   - label shapes come from $data_batch->provide_label when it has entries,
#     otherwise from the current label descriptors combined with the shapes
#     of $data_batch->label when that has entries, otherwise undef.
# The batch is then forwarded through the executor group with $is_train.
method forward(
    AI::MXNet::DataBatch $data_batch,
    Maybe[Bool] :$is_train=
)
{
    assert($self->binded and $self->params_initialized);

    # An empty array ref is true in Perl, so a plain truth test would treat
    # [] as "provided".  This helper makes an empty list fall through to the
    # next branch instead of being used as-is.
    my $has_entries = sub {
        my ($value) = @_;
        return 0 unless $value;
        return ref($value) eq 'ARRAY' ? scalar(@{ $value }) : 1;
    };

    my @curr_data_shapes = map { $_->shape } @{ $self->data_shapes };
    my @new_data_shapes  = map { $_->shape } @{ $data_batch->data };
    # Both sides are dumped with the same Data::Dumper settings, so equal
    # strings mean structurally equal lists of shapes.
    if(Data::Dumper->Dump(\@curr_data_shapes) ne Data::Dumper->Dump(\@new_data_shapes))
    {
        my $new_dshape;
        if($data_batch->can('provide_data') and $has_entries->($data_batch->provide_data))
        {
            $new_dshape = $data_batch->provide_data;
        }
        else
        {
            # keep the current descriptors, swap in the batch's shapes
            $new_dshape = [];
            zip(sub {
                my ($i, $shape) = @_;
                push @{ $new_dshape }, AI::MXNet::DataDesc->new(
                    $i->name, $shape, $i->dtype, $i->layout
                );
            }, $self->data_shapes, \@new_data_shapes);
        }
        # stays undef when the batch has neither label descriptors nor labels
        my $new_lshape;
        if($data_batch->can('provide_label') and $has_entries->($data_batch->provide_label))
        {
            $new_lshape = $data_batch->provide_label;
        }
        elsif($data_batch->can('label') and $has_entries->($data_batch->label))
        {
            $new_lshape = [];
            zip(sub {
                my ($i, $j) = @_;
                push @{ $new_lshape }, AI::MXNet::DataDesc->new(
                    $i->name, $j->shape, $i->dtype, $i->layout
                );
            }, $self->label_shapes, $data_batch->label);
        }
        $self->reshape(data_shapes => $new_dshape, label_shapes => $new_lshape);
    }
    $self->_p->_exec_group->forward($data_batch, $is_train);
}
method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grads=)
{
    # Backward pass: requires a bound module with initialized parameters;
    # $out_grads (optional) is handed straight to the executor group.
    my $ready = $self->binded && $self->params_initialized;
    assert($ready);
    my $exec_group = $self->_p->_exec_group;
    $exec_group->backward($out_grads);
}
method update()
{
assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
$self->_p->_params_dirty(1);
if($self->_p->_update_on_kvstore)
{
_update_params_on_kvstore(
$self->_p->_exec_group->_p->param_arrays,
$self->_p->_exec_group->_p->grad_arrays,
$self->_p->_kvstore,
$self->_p->_exec_group->param_names
);
}
else
{
_update_params(
$self->_p->_exec_group->_p->param_arrays,
$self->_p->_exec_group->_p->grad_arrays,
( run in 0.233 second using v1.01-cache-2.11-cpan-eab888a1d7d )