AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Executor.pm view on Meta::CPAN
package AI::MXNet::Executor;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Context;
use Mouse;
use AI::MXNet::Types;
use AI::MXNet::Function::Parameters;
# Native executor handle (required); it is released in DEMOLISH via
# AI::MXNetCAPI::ExecutorFree.
has 'handle' => (is => 'ro', isa => 'ExecutorHandle', required => 1);
# NDArrays bound to the symbol's arguments; arg_dict pairs them positionally
# with the names from list_arguments.
has 'arg_arrays' => (is => 'rw', isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]');
# Gradient NDArrays paired positionally with list_arguments in grad_dict;
# individual entries may be undef.
has 'grad_arrays' => (is => 'rw', isa => 'Maybe[ArrayRef[Undef|AI::MXNet::NDArray]]');
# Auxiliary NDArrays bound to this executor.
has 'aux_arrays' => (is => 'rw', isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]');
# The next four are passed to new() as symbol / ctx / grad_req / group2ctx.
# BUILD replaces each one with its own copy.
has '_symbol' => (is => 'rw', init_arg => 'symbol', isa => 'AI::MXNet::Symbol');
has '_ctx' => (is => 'rw', init_arg => 'ctx', isa => 'AI::MXNet::Context' );
# Either a single string, an array ref of strings or a hash ref of strings.
has '_grad_req' => (is => 'rw', init_arg => 'grad_req', isa => 'Maybe[Str|ArrayRef[Str]|HashRef[Str]]');
has '_group2ctx' => (is => 'rw', init_arg => 'group2ctx', isa => 'Maybe[HashRef[AI::MXNet::Context]]');
# Code ref stored by set_monitor_callback.
has '_monitor_callback' => (is => 'rw', isa => 'CodeRef');
# Internal state that cannot be set through the constructor (init_arg => undef):
# lazily built name => NDArray caches (_arg_dict and _grad_dict are filled by
# arg_dict/grad_dict; _aux_dict/_output_dict presumably work the same way),
# the output NDArrays created in BUILD, and the flag that forward sets to
# $is_train and backward resets to 0.
has [qw/_arg_dict
_grad_dict
_aux_dict
_output_dict
outputs
_output_dirty/] => (is => 'rw', init_arg => undef);
=head1 NAME
AI::MXNet::Executor - The actual executing object of MXNet.
=head2 new
Constructor, used by AI::MXNet::Symbol->bind and by AI::MXNet::Symbol->simple_bind.
Parameters
----------
handle: ExecutorHandle
ExecutorHandle is generated by calling bind.
See Also
--------
AI::MXNet::Symbol->bind : how to create the AI::MXNet::Executor.
=cut
# Post-construction hook (Mouse calls BUILD after attribute initialisation).
# Replaces the constructor inputs with copies: the symbol and the context via
# their deepcopy methods, and grad_req / group2ctx containers via shallow
# copies. It then wraps the executor's current outputs as NDArrays.
sub BUILD
{
    my ($self) = @_;
    my $symbol_copy = $self->_symbol->deepcopy;
    my $ctx_copy    = $self->_ctx->deepcopy;
    my $grad_req    = $self->_grad_req;
    my $group2ctx   = $self->_group2ctx;

    # grad_req may be a plain string (kept as is) or an array/hash ref,
    # which is copied one level deep.
    my $req_kind = ref $grad_req;
    if($req_kind eq 'ARRAY')
    {
        $grad_req = [ @{ $grad_req } ];
    }
    elsif($req_kind eq 'HASH')
    {
        $grad_req = { %{ $grad_req } };
    }
    $group2ctx = { %{ $group2ctx } } if ref $group2ctx;

    $self->_symbol($symbol_copy);
    $self->_ctx($ctx_copy);
    $self->_grad_req($grad_req);
    $self->_group2ctx($group2ctx);
    $self->outputs($self->_get_outputs);
}
# Destruction hook: frees the native executor behind $self->handle.
sub DEMOLISH
{
    my ($self) = @_;
    my $executor_handle = $self->handle;
    check_call(AI::MXNetCAPI::ExecutorFree($executor_handle));
}
# Build a hash ref mapping each name in $names to the NDArray at the same
# position in $ndarrays. Missing trailing arrays map to undef; surplus arrays
# are ignored. Confesses if $names contains a duplicate.
func _get_dict(
    ArrayRef[Str] $names,
    ArrayRef[Maybe[AI::MXNet::NDArray]] $ndarrays
)
{
    my %seen;
    for my $name (@{ $names })
    {
        confess("Duplicate names detected, @$names") if $seen{ $name }++;
    }
    # '{;' forces map to parse a block rather than an anonymous hash.
    my %by_name = map {; $names->[$_] => $ndarrays->[$_] } 0 .. $#{ $names };
    return \%by_name;
}
=head2 outputs
The output ndarrays bound to this executor.
Returns
-------
An array ref with AI::MXNet::NDArray objects bound to the heads of the executor.
=cut
# Ask the backend for the executor's output handles and wrap each one in an
# AI::MXNet::NDArray, preserving order.
method _get_outputs()
{
    my $output_handles = check_call(AI::MXNetCAPI::ExecutorOutputs($self->handle));
    my @wrapped;
    for my $output_handle (@{ $output_handles })
    {
        push @wrapped, AI::MXNet::NDArray->new(handle => $output_handle);
    }
    return \@wrapped;
}
=head2 forward
Calculate the outputs specified by the bound symbol.
Parameters
----------
$is_train=0: bool, optional
Whether this forward pass is for training. If true (1), a backward
call is expected to follow. Otherwise a following backward call is
invalid.
%kwargs
Additional specification of input arguments.
Examples
--------
>>> # doing forward by specifying data
lib/AI/MXNet/Executor.pm view on Meta::CPAN
$is_train
)
);
if($self->_output_dirty)
{
AI::MXNet::Logging->warning(
"Calling forward the second time after forward(is_train=1) "
."without calling backward first. Is this intended?"
);
}
$self->_output_dirty($is_train);
return $self->outputs;
}
=head2 backward
Do a backward pass to get the gradient of the arguments.
Parameters
----------
out_grads : NDArray or an array ref of NDArrays or hash ref of NDArrays, optional.
The gradient on the outputs to be propagated back.
This parameter is only needed when bind is called
on outputs that are not a loss function.
A hash ref must contain an entry for every name returned by the bound
symbol's list_outputs; the gradients are passed on in that order.
A warning is logged if forward was not called with is_train=1 since the
last backward.
=cut
method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]|HashRef[AI::MXNet::NDArray]] $out_grads=)
{
    # Normalise every accepted form of $out_grads to an array ref.
    $out_grads //= [];
    if(blessed $out_grads)
    {
        $out_grads = [$out_grads];
    }
    elsif(ref $out_grads eq 'HASH')
    {
        # The bound symbol lives in the '_symbol' attribute; no 'symbol'
        # accessor is declared on this class.
        my @output_names = @{ $self->_symbol->list_outputs() };
        my @missing = grep { not defined $out_grads->{ $_ } } @output_names;
        if(@missing)
        {
            confess("out_grads is missing gradients for outputs: @missing");
        }
        $out_grads = [ @{ $out_grads }{ @output_names } ];
    }
    check_call(
        AI::MXNetCAPI::ExecutorBackward(
            $self->handle,
            scalar(@{ $out_grads }),
            [map { $_->handle } @{ $out_grads }]
        )
    );
    # forward() sets _output_dirty to its is_train flag; a false value here
    # means no training forward preceded this backward.
    if(not $self->_output_dirty)
    {
        AI::MXNet::Logging->warning(
            "Calling backward without calling forward(is_train=True) "
            ."first. Behavior is undefined."
        );
    }
    $self->_output_dirty(0);
}
=head2 set_monitor_callback
Install a monitor callback on this executor.
Parameters
----------
callback : subref
Takes a string and an NDArrayHandle.
=cut
method set_monitor_callback(CodeRef $callback)
{
    # Store the code ref on the object too; presumably this keeps it alive
    # while the native executor can still invoke it (TODO confirm).
    $self->_monitor_callback($callback);
    my $executor_handle = $self->handle;
    my $stored_callback = $self->_monitor_callback;
    check_call(
        AI::MXNetCAPI::ExecutorSetMonitorCallback($executor_handle, $stored_callback)
    );
}
=head2 arg_dict
Get a hash ref representation of the argument arrays.
Returns
-------
arg_dict : HashRef[AI::MXNet::NDArray]
The map that maps a name of the arguments to the NDArrays.
It is built on the first call and cached; later calls return the same hash ref.
=cut
method arg_dict()
{
    unless(defined $self->_arg_dict)
    {
        # List assignment keeps list_arguments in list context, as before.
        my ($arg_names) = $self->_symbol->list_arguments();
        $self->_arg_dict(_get_dict($arg_names, $self->arg_arrays));
    }
    return $self->_arg_dict;
}
=head2 grad_dict
Get a hash ref representation of the gradient arrays.
Returns
-------
grad_dict : HashRef[AI::MXNet::NDArray]
The map that maps a name of the arguments to the gradient NDArrays.
Values may be undef for arguments without a gradient array. The map is
built on the first call and cached.
=cut
method grad_dict()
{
    return $self->_grad_dict if defined $self->_grad_dict;
    # List assignment keeps list_arguments in list context, as before.
    my ($arg_names) = $self->_symbol->list_arguments();
    $self->_grad_dict(_get_dict($arg_names, $self->grad_arrays));
    return $self->_grad_dict;
}
( run in 1.083 second using v1.01-cache-2.11-cpan-39bf76dae61 )