AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Executor.pm  view on Meta::CPAN

package AI::MXNet::Executor;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Context;
use Mouse;
use AI::MXNet::Types;
use AI::MXNet::Function::Parameters;

# Native executor handle; mandatory and immutable after construction.
has 'handle' => (
    is       => 'ro',
    isa      => 'ExecutorHandle',
    required => 1,
);

# Bound argument / gradient / auxiliary-state arrays.
has 'arg_arrays' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]',
);
# Individual gradient slots may be undef (allowed by the type below).
has 'grad_arrays' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[Undef|AI::MXNet::NDArray]]',
);
has 'aux_arrays' => (
    is  => 'rw',
    isa => 'Maybe[ArrayRef[AI::MXNet::NDArray]]',
);

# Private copies of the bind-time arguments; the constructor accepts them
# under their public names (symbol, ctx, grad_req, group2ctx).
has '_symbol' => (
    is       => 'rw',
    init_arg => 'symbol',
    isa      => 'AI::MXNet::Symbol',
);
has '_ctx' => (
    is       => 'rw',
    init_arg => 'ctx',
    isa      => 'AI::MXNet::Context',
);
has '_grad_req' => (
    is       => 'rw',
    init_arg => 'grad_req',
    isa      => 'Maybe[Str|ArrayRef[Str]|HashRef[Str]]',
);
has '_group2ctx' => (
    is       => 'rw',
    init_arg => 'group2ctx',
    isa      => 'Maybe[HashRef[AI::MXNet::Context]]',
);

# Callback registered through set_monitor_callback.
has '_monitor_callback' => (
    is  => 'rw',
    isa => 'CodeRef',
);

# Internal state; init_arg => undef means these cannot be passed to new().
has [qw/_arg_dict _grad_dict _aux_dict _output_dict outputs _output_dirty/] => (
    is       => 'rw',
    init_arg => undef,
);
=head1 NAME

    AI::MXNet::Executor - The actual executing object of MXNet.

=head2 new

    Constructor, used by AI::MXNet::Symbol->bind and by AI::MXNet::Symbol->simple_bind.

    Parameters
    ----------
    handle: ExecutorHandle
        ExecutorHandle is generated by calling bind.

    See Also
    --------
    AI::MXNet::Symbol->bind : how to create the AI::MXNet::Executor.
=cut

sub BUILD
{
    my ($self) = @_;

    # Give the executor its own copies of the construction arguments:
    # the symbol and context are deep-copied, while array/hash forms of
    # grad_req and the group2ctx hash are copied one level deep (their
    # elements are shared with the caller's containers).
    my $symbol    = $self->_symbol->deepcopy;
    my $ctx       = $self->_ctx->deepcopy;
    my $grad_req  = $self->_grad_req;
    my $group2ctx = $self->_group2ctx;

    my $req_kind = ref $grad_req;
    if($req_kind eq 'ARRAY')
    {
        $grad_req = [ @{ $grad_req } ];
    }
    elsif($req_kind eq 'HASH')
    {
        $grad_req = { %{ $grad_req } };
    }
    $group2ctx = { %{ $group2ctx } } if ref $group2ctx;

    $self->_symbol($symbol);
    $self->_ctx($ctx);
    $self->_grad_req($grad_req);
    $self->_group2ctx($group2ctx);

    # Populate `outputs` via _get_outputs (defined elsewhere in this class).
    $self->outputs($self->_get_outputs);
}

sub DEMOLISH
{
    my ($self) = @_;
    # Release the native executor behind this object's handle when the
    # Perl object is destroyed.
    check_call(AI::MXNetCAPI::ExecutorFree($self->handle));
}

# Build a hash ref mapping each name in $names to the NDArray at the same
# position in $ndarrays. Dies (confess) if $names contains a duplicate.
# Names past the end of $ndarrays map to undef; surplus arrays are ignored.
func _get_dict(
    ArrayRef[Str]                       $names,
    ArrayRef[Maybe[AI::MXNet::NDArray]] $ndarrays
)
{
    my %seen;
    for my $name (@{ $names })
    {
        confess("Duplicate names detected, @$names") if $seen{ $name }++;
    }

    my %by_name;
    for my $i (0 .. $#{ $names })
    {
        $by_name{ $names->[$i] } = $ndarrays->[$i];
    }
    return \%by_name;
}

=head2 outputs

lib/AI/MXNet/Executor.pm  view on Meta::CPAN

        >>> # doing forward by not specifying things, but copy to the executor before hand
        >>> $mydata->copyto($texec->arg_dict->{'data'});
        >>> $texec->forward(1);
        >>> # doing forward by specifying data and get outputs
        >>> my $outputs = $texec->forward(1, data => $mydata);
        >>> print $outputs->[0]->aspdl;
=cut

method forward(Int $is_train=0, %kwargs)
{
    # Run a forward pass and return $self->outputs.
    #
    # $is_train : passed straight to ExecutorForward.
    # %kwargs   : optional name => value pairs copied into the bound argument
    #             arrays before running. Values must satisfy the
    #             'AcceptableInput' type; non-NDArray values are converted
    #             with AI::MXNet::NDArray->array. Dies on an unknown name or a
    #             shape mismatch.
    if(%kwargs)
    {
        my $arg_dict = $self->arg_dict;
        for my $name (keys %kwargs)
        {
            my $array = $kwargs{ $name };
            if(not find_type_constraint('AcceptableInput')->check($array))
            {
                confess('only accept keyword argument of NDArrays/PDLs/Perl Array refs');
            }
            if(not exists $arg_dict->{ $name })
            {
                confess("unknown argument $name");
            }
            if(not blessed($array) or not $array->isa('AI::MXNet::NDArray'))
            {
                $array = AI::MXNet::NDArray->array($array);
            }
            my $expected = $arg_dict->{ $name }->shape;
            my $got      = $array->shape;
            if(join(',', @{ $expected }) ne join(',', @{ $got }))
            {
                confess("Shape not match! Argument $name, need: @$expected, received: @$got");
            }
            # `.=` is overloaded by AI::MXNet::NDArray; presumably an in-place
            # copy into the bound argument array - confirm in NDArray.pm.
            $arg_dict->{ $name } .= $array;
        }
    }
    check_call(AI::MXNetCAPI::ExecutorForward($self->handle, $is_train));
    # _output_dirty is left true by a previous forward(1) and cleared by
    # backward, so a true value here means two training forwards in a row.
    if($self->_output_dirty)
    {
        AI::MXNet::Logging->warning(
            "Calling forward the second time after forward(is_train=1) "
            ."without calling backward first. Is this intended?"
        );
    }
    $self->_output_dirty($is_train);
    return $self->outputs;
}

=head2 backward

    Do a backward pass to get the gradient of the arguments.

    Parameters
    ----------
    out_grads : NDArray, array ref of NDArrays, or hash ref of NDArrays, optional.
        The gradients on the outputs to be propagated back. A hash ref
        must be keyed by output name (as returned by list_outputs).
        This parameter is only needed when bind is called
        on outputs that are not a loss function.
=cut

method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]|HashRef[AI::MXNet::NDArray]] $out_grads=)
{
    # Run a backward pass.
    #
    # $out_grads : optional gradient(s) on the outputs - a single NDArray,
    #              an array ref of NDArrays, or a hash ref keyed by output
    #              name. It is normalised to an array ref before the C call.
    $out_grads //= [];
    if(blessed $out_grads)
    {
        $out_grads = [$out_grads];
    }
    elsif(ref $out_grads eq 'HASH')
    {
        # Order the hash entries by the symbol's output names. Use the
        # _symbol accessor declared on this class.
        my $output_names = $self->_symbol->list_outputs();
        my @missing = grep { not exists $out_grads->{ $_ } } @{ $output_names };
        # A missing entry would otherwise become undef and die obscurely
        # when ->handle is called on it below.
        if(@missing)
        {
            confess("Missing out_grads for outputs: @missing");
        }
        $out_grads = [ @{ $out_grads }{ @{ $output_names } } ];
    }
    check_call(
        AI::MXNetCAPI::ExecutorBackward(
            $self->handle,
            scalar(@{ $out_grads }),
            [map { $_->handle } @{ $out_grads }]
        )
    );
    # _output_dirty is only true after forward(1); warn if that did not happen.
    if(not $self->_output_dirty)
    {
        AI::MXNet::Logging->warning(
            "Calling backward without calling forward(is_train=True) "
            ."first. Behavior is undefined."
        );
    }
    $self->_output_dirty(0);
}

=head2 set_monitor_callback

    Install a monitor callback on this executor.

    Parameters
    ----------
    callback : subref
        Takes a string and an NDArrayHandle.
=cut

method set_monitor_callback(CodeRef $callback)
{
    # Store the sub on the executor first, then hand that stored reference
    # to the C API (presumably so Perl keeps it alive for as long as the
    # native side may invoke it).
    $self->_monitor_callback($callback);
    my $handle = $self->handle;
    check_call(
        AI::MXNetCAPI::ExecutorSetMonitorCallback($handle, $self->_monitor_callback)
    );
}

=head2 arg_dict

    Get a hash ref representation of the argument arrays.

    Returns
    -------
    arg_dict : HashRef[AI::MXNet::NDArray]
        The map that maps a name of the arguments to the NDArrays.

lib/AI/MXNet/Executor.pm  view on Meta::CPAN

            }
            else
            {
                $new_arg_dict{ $name } = $arr->reshape($new_shape);
                if(defined $darr)
                {
                    $new_grad_dict{ $name } = $darr->reshape($new_shape);
                }
            }
        }
        else
        {
            confess(
                    "Shape of unspecified array arg:$name changed. "
                    ."This can cause the new executor to not share parameters "
                    ."with the old one. Please check for error in network."
                    ."If this is intended, set partial_shaping=True to suppress this warning."
            );
        }
        $i++;
    }
    my %new_aux_dict;
    $i = 0;
    for my $name (@{ $self->_symbol->list_auxiliary_states() })
    {
        my $new_shape = $aux_shapes->[$i];
        my $arr = $self->aux_arrays->[$i];
        if($partial_shaping or join(',', @{ $new_shape }) eq join (',', @{ $arr->shape }))
        {
            if(AI::MXNet::NDArray->size($new_shape) > $arr->size)
            {
                confess(
                    "New shape of arg:$name larger than original. "
                    ."First making a big executor and then down sizing it "
                    ."is more efficient than the reverse."
                    ."If you really want to up size, set \$allow_up_sizing=1 "
                    ."to enable allocation of new arrays."
                ) unless $allow_up_sizing;
                $new_aux_dict{ $name }  = AI::MXNet::NDArray->empty(
                    $new_shape,
                    ctx => $arr->context,
                    dtype => $arr->dtype
                );
            }
            else
            {
                $new_aux_dict{ $name } = $arr->reshape($new_shape);
            }
        }
        else
        {
            confess(
                "Shape of unspecified array aux:$name changed. "
                ."This can cause the new executor to not share parameters "
                ."with the old one. Please check for error in network."
                ."If this is intended, set partial_shaping=True to suppress this warning."
            );
        }
        $i++;
    }
    return $self->_symbol->bind(
                ctx         => $self->_ctx,
                args        => \%new_arg_dict,
                args_grad   => \%new_grad_dict,
                grad_req    => $self->_grad_req,
                aux_states  => \%new_aux_dict,
                group2ctx   => $self->_group2ctx,
                shared_exec => $self
    );
}

=head2 debug_str

    A debug string about the internal execution plan.

    Returns
    -------
    debug_str : string
        Debug string of the executor.
=cut

method debug_str()
{
    # Assigning to a scalar evaluates check_call in scalar context, so a
    # single value (the debug string) is returned rather than a list.
    my $plan = check_call(AI::MXNetCAPI::ExecutorPrint($self->handle));
    return $plan;
}

1;



( run in 1.065 second using v1.01-cache-2.11-cpan-39bf76dae61 )