AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Executor.pm  view on Meta::CPAN

package AI::MXNet::Executor;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Context;
use Mouse;
use AI::MXNet::Types;
use AI::MXNet::Function::Parameters;

# Object attributes (declared with Mouse).
#
# 'handle' is the only required constructor argument; it is the executor
# handle that DEMOLISH passes to AI::MXNetCAPI::ExecutorFree.
has 'handle'            => (is => 'ro', isa => 'ExecutorHandle', required => 1);
# Optional array refs of AI::MXNet::NDArray objects. Entries of grad_arrays
# may be undef; the other two lists may not contain undef entries.
has 'arg_arrays'        => (is => 'rw', isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]');
has 'grad_arrays'       => (is => 'rw', isa => 'Maybe[ArrayRef[Undef|AI::MXNet::NDArray]]'); 
has 'aux_arrays'        => (is => 'rw', isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]');
# The following four attributes are stored under private names but are
# supplied to the constructor as 'symbol', 'ctx', 'grad_req' and 'group2ctx'.
# BUILD replaces each of them with a copy of the value that was passed in.
has '_symbol'           => (is => 'rw', init_arg => 'symbol',    isa => 'AI::MXNet::Symbol');
has '_ctx'              => (is => 'rw', init_arg => 'ctx',       isa => 'AI::MXNet::Context' );
# grad_req may be a plain string, an array ref of strings or a hash ref of strings.
has '_grad_req'         => (is => 'rw', init_arg => 'grad_req',  isa => 'Maybe[Str|ArrayRef[Str]|HashRef[Str]]');
has '_group2ctx'        => (is => 'rw', init_arg => 'group2ctx', isa => 'Maybe[HashRef[AI::MXNet::Context]]');
# Optional code ref; presumably installed by a monitor-related method that is
# not part of this section -- TODO confirm.
has '_monitor_callback' => (is => 'rw', isa => 'CodeRef');
# Internal state: init_arg => undef means none of these can be set through
# the constructor. 'outputs' is filled in by BUILD with the result of
# _get_outputs.
has [qw/_arg_dict
        _grad_dict
        _aux_dict
        _output_dict
        outputs
        _output_dirty/] => (is => 'rw', init_arg => undef);
=head1 NAME

    AI::MXNet::Executor - The actual executing object of MXNet.

=head2 new

    Constructor, used by AI::MXNet::Symbol->bind and by AI::MXNet::Symbol->simple_bind.

    Parameters
    ----------
    handle: ExecutorHandle
        ExecutorHandle is generated by calling bind.

    See Also
    --------
    AI::MXNet::Symbol->bind : how to create the AI::MXNet::Executor.

=cut

# Mouse construction hook.
#
# Replaces the symbol, context, grad_req and group2ctx values supplied to the
# constructor with copies, then stores the executor's output arrays in the
# 'outputs' attribute.
#
# The symbol and the context are copied with their own deepcopy methods.
# grad_req and group2ctx are copied one level deep, and only when they are
# references (an array or hash ref for grad_req, any ref for group2ctx);
# plain strings and undef are stored back unchanged.
sub BUILD
{
    my ($self) = @_;

    my $own_symbol = $self->_symbol->deepcopy;
    my $own_ctx    = $self->_ctx->deepcopy;

    # ref() yields '' for undef and for plain strings, so neither branch
    # fires in those cases and the value is kept as it is.
    my $grad_req = $self->_grad_req;
    my $req_type = ref $grad_req;
    if($req_type eq 'ARRAY')
    {
        $grad_req = [ @$grad_req ];
    }
    elsif($req_type eq 'HASH')
    {
        $grad_req = { %$grad_req };
    }

    my $group2ctx = $self->_group2ctx;
    $group2ctx = { %$group2ctx } if ref $group2ctx;

    $self->_symbol($own_symbol);
    $self->_ctx($own_ctx);
    $self->_grad_req($grad_req);
    $self->_group2ctx($group2ctx);

    # Wrap the executor's output handles once, at construction time.
    $self->outputs($self->_get_outputs);
}

# Mouse destruction hook: hands the executor handle to
# AI::MXNetCAPI::ExecutorFree and passes the result through check_call.
sub DEMOLISH
{
    my ($self) = @_;
    check_call(AI::MXNetCAPI::ExecutorFree($self->handle));
}

# Build a name => ndarray hash ref from two parallel lists.
#
# Parameters
# ----------
# $names    : array ref of strings, the keys of the resulting hash.
# $ndarrays : array ref of AI::MXNet::NDArray objects (or undef entries),
#             paired with $names by position.
#
# Returns a hash ref mapping each name to the ndarray at the same index.
# Dies via confess (with a stack trace) if a name occurs more than once.
func _get_dict(
    ArrayRef[Str]                       $names,
    ArrayRef[Maybe[AI::MXNet::NDArray]] $ndarrays
)
{
    # Reject duplicates first: a repeated key would silently overwrite
    # an earlier entry in the hash built below.
    my %seen;
    for my $name (@$names)
    {
        confess("Duplicate names detected, @$names") if $seen{ $name }++;
    }

    # Pair names and arrays by position with a hash slice.
    my %dict;
    @dict{ @$names } = @$ndarrays;
    return \%dict;
}

=head2 outputs

    The output ndarrays bound to this executor.

    Returns
    -------
    An array ref with AI::MXNet::NDArray objects bound to the heads of the executor.

=cut

# Fetch the executor's output handles through the C API and wrap each one in
# an AI::MXNet::NDArray object.
#
# Returns an array ref of AI::MXNet::NDArray objects, in the order in which
# the handles were returned by AI::MXNetCAPI::ExecutorOutputs.
method _get_outputs()
{
    # Call in scalar context: the result is used as a single array reference.
    my $out_handles = check_call(AI::MXNetCAPI::ExecutorOutputs($self->handle));

    my @wrapped;
    for my $out_handle (@$out_handles)
    {
        push @wrapped, AI::MXNet::NDArray->new(handle => $out_handle);
    }
    return \@wrapped;
}

=head2 forward

    Calculate the outputs specified by the bound symbol.

    Parameters
    ----------
    $is_train=0: bool, optional
        Whether this forward pass is for training. If true,
        a backward call is expected to follow. Otherwise a following
        backward call is invalid.

    %kwargs
        Additional specification of input arguments.

    Examples
    --------
        >>> # doing forward by specifying data



( run in 0.510 second using v1.01-cache-2.11-cpan-39bf76dae61 )