AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

    my @arg_arrays  = map { AI::MXNet::NDArray->new(handle => $_) } @{ $in_arg_handles };
    my @grad_arrays = map { defined $_ ? AI::MXNet::NDArray->new(handle => $_) : undef  } @{ $arg_grad_handles };
    my @aux_arrays  = map { AI::MXNet::NDArray->new(handle => $_) } @{ $aux_state_handles };
    my $executor = AI::MXNet::Executor->new(
        handle    => $exe_handle,
        symbol    => $self,
        ctx       => $ctx,
        grad_req  => $grad_req,
        group2ctx => $group2ctx
    );
    $executor->arg_arrays(\@arg_arrays);
    $executor->grad_arrays(\@grad_arrays);
    $executor->aux_arrays(\@aux_arrays);
    return $executor;
}

=head2 bind

    Bind current symbol to get an executor.

    Parameters
    ----------
    :$ctx : AI::MXNet::Context
        The device context on which the generated executor will run.

    :$args : HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]
        Input arguments to the symbol.
            - If type is array ref of NDArray, the position is in the same order of list_arguments.
            - If type is hash ref of str to NDArray, then it maps the name of arguments
                to the corresponding NDArray.
            - In either case, all the arguments must be provided.

    :$args_grad : Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]]
        When specified, args_grad provides NDArrays to hold
        the gradient values computed during the backward pass.
            - If the type is an array ref of NDArray, the positions follow the order of list_arguments.
            - If the type is a hash ref of str to NDArray, it maps the names of arguments
                to the corresponding NDArray.
            - When the type is a hash ref of str to NDArray, users only need to provide entries
                for the argument gradients they need.
        Only the specified argument gradients will be calculated.

    :$grad_req : {'write', 'add', 'null'}, or array ref of str or hash ref of str to str, optional
        Specifies how we should update the gradient to the args_grad.
            - 'write' means the gradient is written to the specified args_grad NDArray every time.
            - 'add' means the gradient is added to the specified NDArray every time.
            - 'null' means no action is taken; the gradient may not be calculated.

    :$aux_states : array ref of NDArray, or hash ref of str to NDArray, optional
        Input auxiliary states to the symbol, only need to specify when
        list_auxiliary_states is not empty.
            - If type is array ref of NDArray, the position is in the same order of list_auxiliary_states
            - If type is hash ref of str to NDArray, then it maps the name of auxiliary_states
                to the corresponding NDArray.
            - In either case, all the auxiliary_states need to be provided.

    :$group2ctx : hash ref of string to AI::MXNet::Context
        The mapping of the ctx_group attribute to the context assignment.

    :$shared_exec : AI::MXNet::Executor
        Executor to share memory with. This is intended for runtime reshaping, variable length
        sequences, etc. The returned executor shares state with shared_exec, and should not be
        used in parallel with it.

    Returns
    -------
    $executor : AI::MXNet::Executor
        The generated Executor

    Notes
    -----
    Auxiliary states are special states of symbols that do not correspond to an argument
    and do not have gradients, but are still useful for specific operations.
    A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
    Most operators do not have auxiliary states and this parameter can be safely ignored.

    Users can skip computing gradients by passing a hash ref in args_grad and specifying
    only the gradients they are interested in.

=cut

method bind(
        AI::MXNet::Context                                              :$ctx,
        HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]        :$args,
        Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]] :$args_grad=,
        Str|HashRef[Str]|ArrayRef[Str]                                  :$grad_req='write',
        Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]] :$aux_states=,
        Maybe[HashRef[AI::MXNet::Context]]                              :$group2ctx=,
        Maybe[AI::MXNet::Executor]                                      :$shared_exec=
)
{
    $grad_req //= 'write';
    my $listed_arguments = $self->list_arguments();
    my ($args_handle, $args_grad_handle, $aux_args_handle) = ([], [], []);
    ($args_handle, $args) = $self->_get_ndarray_inputs('args', $args, $listed_arguments);
    if(not defined $args_grad)
    {
        @$args_grad_handle = ((undef) x (@$args));
    }
    else
    {
        ($args_grad_handle, $args_grad) = $self->_get_ndarray_inputs(
                'args_grad', $args_grad, $listed_arguments, 1
        );
    }

    if(not defined $aux_states)
    {
        $aux_states = [];
    }
    ($aux_args_handle, $aux_states) = $self->_get_ndarray_inputs(
            'aux_states', $aux_states, $self->list_auxiliary_states()
    );

    # setup requirements
    my $req_map = { null => 0, write => 1, add =>  3 };
    my $req_array = [];
    if(not ref $grad_req)
    {
        confess('grad_req must be one of "null,write,add"')
            unless exists $req_map->{ $grad_req };
        @{ $req_array } = (($req_map->{ $grad_req }) x @{ $listed_arguments });



( run in 0.576 second using v1.01-cache-2.11-cpan-39bf76dae61 )