AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
package AI::MXNet::BatchEndParam;
use Mouse;
use AI::MXNet::Function::Parameters;
# Plain read/write record: epoch and batch counters plus an evaluation metric.
# NOTE(review): presumably handed to batch-end callbacks during training — confirm.
has 'epoch'       => (is => 'rw', isa => 'Int');
has 'nbatch'      => (is => 'rw', isa => 'Int');
has 'eval_metric' => (is => 'rw', isa => 'AI::MXNet::EvalMetric');
package AI::MXNet::Module::Base;
use Mouse;
use AI::MXNet::Base;
use Time::HiRes qw(time);
=head1 NAME
AI::MXNet::Module::Base - Base class for AI::MXNet::Module and AI::MXNet::Module::Bucketing
=cut
# Wrap a value in an array ref unless it already is a plain (unblessed) array
# ref; blessed objects and non-references (including undef) get wrapped.
func _as_list($obj)
{
    return $obj if ref($obj) eq 'ARRAY';
    return [$obj];
}
# Check that every requested input name is one of the symbol's arguments.
# A name that is missing (other than 'softmax_label', which is always
# tolerated) is reported via confess when $throw is true, or via a logged
# warning otherwise. The message suggests the symbol's arguments whose names
# do not end in _weight/_bias/_gamma/_beta.
method _check_input_names(
    AI::MXNet::Symbol $symbol,
    ArrayRef[Str]     $names,
    Str               $typename,
    Bool              $throw
)
{
    my @arguments   = @{ $symbol->list_arguments };
    my %is_argument = map { ($_ => 1) } @arguments;
    # Suggestions exclude learned-parameter-looking names.
    my @candidates  = grep { not /_(?:weight|bias|gamma|beta)$/ } @arguments;
    for my $name (@$names)
    {
        next if exists $is_argument{$name} or $name eq 'softmax_label';
        my $msg = sprintf("\033[91mYou created Module with Module(..., %s_names=%s) but "
            ."input with name '%s' is not found in symbol.list_arguments(). "
            ."Did you mean one of:\n\t%s\033[0m",
            $typename, "@$names", $name, join("\n\t", @candidates)
        );
        if($throw)
        {
            confess($msg);
        }
        else
        {
            AI::MXNet::Logging->warning($msg);
        }
    }
}
# Check that the names carried by the data descriptors match the declared
# input names, position by position. On a mismatch, confess when $throw is
# true, otherwise log a warning.
method _check_names_match(
    ArrayRef[Str]                  $data_names,
    ArrayRef[NameShapeOrDataDesc]  $data_shapes,
    Str                            $name,
    Bool                           $throw
)
{
    # A lone 'softmax_label' with no shapes supplied is accepted silently.
    return if (not @$data_shapes and @$data_names == 1 and $data_names->[0] eq 'softmax_label');
    # Each descriptor (DataDesc or [name, shape] pair) exposes its name at index 0.
    my @actual = map { $_->[0] } @{ $data_shapes };
    # Compare element-wise: joining with spaces would wrongly treat
    # ['a b'] and ['a', 'b'] as equal.
    my $same = @actual == @$data_names
        && !grep { $actual[$_] ne $data_names->[$_] } 0 .. $#actual;
    return if $same;
    # Report the provided names rather than stringified descriptor refs.
    my $msg = sprintf(
        "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)",
        $name, $name, "@actual", "@$data_names"
    );
    if($throw)
    {
        confess($msg);
    }
    else
    {
        AI::MXNet::Logging->warning($msg);
    }
}
# Normalise data/label shape specs into AI::MXNet::DataDesc objects and check
# them against the declared names. A data-name mismatch confesses; a label-name
# mismatch only warns. Returns ($data_shapes, $label_shapes); $label_shapes
# stays undef when the caller passed none.
method _parse_data_desc(
    ArrayRef[Str]                         $data_names,
    Maybe[ArrayRef[Str]]                  $label_names,
    ArrayRef[NameShapeOrDataDesc]         $data_shapes,
    Maybe[ArrayRef[NameShapeOrDataDesc]]  $label_shapes
)
{
    # Blessed descriptors pass through; plain array refs are expanded into
    # DataDesc->new(@$pair).
    my $to_desc = sub {
        return [map { blessed $_ ? $_ : AI::MXNet::DataDesc->new(@$_) } @{ $_[0] }];
    };
    $data_shapes = $to_desc->($data_shapes);
    $self->_check_names_match($data_names, $data_shapes, 'data', 1);
    $label_shapes = $to_desc->($label_shapes) if $label_shapes;
    # $label_names may be undef, but _check_names_match requires an array ref.
    $self->_check_names_match($label_names // [], $label_shapes // [], 'label', 0);
    return ($data_shapes, $label_shapes);
}
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
{
$self->init_params(
initializer => undef,
arg_params => $arg_params,
aux_params => $aux_params,
allow_missing => $allow_missing,
force_init => $force_init,
allow_extra => $allow_extra
);
}
=head2 save_params
Save model parameters to file.
Parameters
----------
$fname : str
Path to output param file.
$arg_params= : Maybe[HashRef[AI::MXNet::NDArray]]
Argument parameters to save; defaults to the ones returned by get_params.
$aux_params= : Maybe[HashRef[AI::MXNet::NDArray]]
Auxiliary parameters to save; defaults to the ones returned by get_params.
=cut
method save_params(
    Str                                 $fname,
    Maybe[HashRef[AI::MXNet::NDArray]]  $arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]]  $aux_params=
)
{
    # Fill in only what the caller omitted, so an explicitly passed hash is
    # never silently replaced by the module's current parameters.
    if(not defined $arg_params or not defined $aux_params)
    {
        my ($current_arg, $current_aux) = $self->get_params;
        $arg_params //= $current_arg;
        $aux_params //= $current_aux;
    }
    my $cpu = AI::MXNet::Context->cpu;
    my %save_dict;
    # keys() resets the hash iterator (unlike each()), so an iterator left
    # half-consumed elsewhere cannot cause entries to be skipped.
    for my $k (keys %{ $arg_params })
    {
        $save_dict{"arg:$k"} = $arg_params->{$k}->as_in_context($cpu);
    }
    for my $k (keys %{ $aux_params })
    {
        $save_dict{"aux:$k"} = $aux_params->{$k}->as_in_context($cpu);
    }
    AI::MXNet::NDArray->save($fname, \%save_dict);
}
=head2 load_params
Load model parameters from file.
Parameters
----------
$fname : str
Path to input param file.
=cut
# Load "arg:<name>" / "aux:<name>" entries from $fname and install them via
# set_params. Confesses "Invalid param file" on any other layout.
method load_params(Str $fname)
{
    my $loaded = AI::MXNet::NDArray->load($fname);
    # Only a name => NDArray hash can carry the arg:/aux: prefixes; anything
    # else (e.g. an array ref of unnamed arrays) would otherwise die with an
    # opaque "Not a HASH reference" on dereference.
    confess("Invalid param file $fname") unless ref($loaded) eq 'HASH';
    my %arg_params;
    my %aux_params;
    for my $key (keys %{ $loaded })
    {
        # Split on the first colon only; the parameter name may contain more.
        my ($arg_type, $name) = split(/:/, $key, 2);
        if($arg_type eq 'arg')
        {
            $arg_params{ $name } = $loaded->{$key};
        }
        elsif($arg_type eq 'aux')
        {
            $aux_params{ $name } = $loaded->{$key};
        }
        else
        {
            confess("Invalid param file $fname");
        }
    }
    $self->set_params(\%arg_params, \%aux_params);
}
=head2 get_states
The states from all devices
Parameters
----------
$merge_multi_context=1 : Bool
Default is true (1). In the case when data-parallelism is used, the states
will be collected from multiple devices. A true value indicates that we
should merge the collected results so that they look like from a single
executor.
Returns
-------
If $merge_multi_context is 1, it is like [$out1, $out2]. Otherwise, it
is like [[$out1_dev1, $out1_dev2], [$out2_dev1, $out2_dev2]]. All the output
elements are AI::MXNet::NDArray.
=cut
method get_states(Bool $merge_multi_context=1)
{
    # The base implementation holds no states. It requires a bound,
    # initialised module and an explicit unmerged request, and answers with
    # an empty list. Note that the default ($merge_multi_context=1) fails
    # the second assertion.
    assert($self->binded and $self->params_initialized);
    assert(!$merge_multi_context);
    my @no_states = ();
    return [@no_states];
}
=head2 set_states
Set value for states. You can specify either $states or $value, not both.
Parameters
----------
$states= : Maybe[ArrayRef[ArrayRef[AI::MXNet::NDArray]]]
source states arrays formatted like [[$state1_dev1, $state1_dev2],
[$state2_dev1, $state2_dev2]].
$value= : Maybe[Num]
a single scalar value for all state arrays.
=cut
method set_states(Maybe[ArrayRef[ArrayRef[AI::MXNet::NDArray]]] $states=, Maybe[Num] $value=)
{
    # Stateless base: once bound and initialised, it only accepts a call
    # where neither $states nor $value is true.
    assert($self->binded and $self->params_initialized);
    assert(!$states && !$value);
}
=head2 install_monitor
Install monitor on all executors
lib/AI/MXNet/Module/Base.pm view on Meta::CPAN
:$inputs_need_grad=0 : Bool
Default is 0. Whether the gradients with respect to the input data need to be computed.
Typically this is not needed, but it may be needed when implementing a composition
of modules.
:$force_rebind=0 : Bool
Default is 0. This function does nothing if the executors are already
bound. If set to 1, the executors are forced to rebind.
:$shared_module= : A subclass of AI::MXNet::Module::Base
Default is undef. This is used in bucketing. When not undef, the shared module
essentially corresponds to a different bucket -- a module with different symbol
but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
:$grad_req='write' : Str|ArrayRef[Str]|HashRef[Str]
Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
(defaults to 'write').
Can be specified globally (str) or for each argument (array ref, hash ref).
=cut
# Abstract: concrete module classes must override bind(); this base version
# only validates its arguments and confesses "NotImplemented".
method bind(
    ArrayRef[AI::MXNet::DataDesc]         $data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc]]  :$label_shapes=,
    Bool                                  :$for_training=1,
    Bool                                  :$inputs_need_grad=0,
    Bool                                  :$force_rebind=0,
    # Typed as this base class so any AI::MXNet::Module::Base subclass is
    # accepted (the previous AI::MXNet::BaseModule is not this class).
    Maybe[AI::MXNet::Module::Base]        :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]        :$grad_req='write'
)
{
    confess("NotImplemented")
}
=head2 init_optimizer
Install and initialize optimizers.
Parameters
----------
:$kvstore='local' : str or KVStore
:$optimizer='sgd' : str or Optimizer
:$optimizer_params={ learning_rate => 0.01 } : hash ref
:$force_init=0 : Bool
=cut
# Abstract: concrete module classes must override init_optimizer(); this base
# version only validates its arguments and confesses "NotImplemented".
# $kvstore accepts either a store name or an AI::MXNet::KVStore object,
# matching the documented "str or KVStore".
method init_optimizer(
    Str|AI::MXNet::KVStore  :$kvstore='local',
    Optimizer               :$optimizer='sgd',
    HashRef                 :$optimizer_params={ learning_rate => 0.01 },
    Bool                    :$force_init=0
)
{
    confess("NotImplemented")
}
################################################################################
# misc
################################################################################
=head2 symbol
The symbol associated with this module.
Except for AI::MXNet::Module, for other types of modules (e.g. AI::MXNet::Module::Bucketing), this
property might not be constant throughout its lifetime. Some modules might
not even be associated with any symbols.
=cut
method symbol()
{
# Delegates to the private _symbol accessor, returning its result in the
# caller's context.
return $self->_symbol;
}
1;
( run in 1.623 second using v1.01-cache-2.11-cpan-df04353d9ac )