AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Executor.pm  view on Meta::CPAN

=head2 outputs

    The output NDArrays bound to this executor.

    Returns
    -------
    An array ref of AI::MXNet::NDArray objects, one per output handle
    that the C API (ExecutorOutputs) reports for this executor, i.e. the
    arrays bound to the heads of the executor.

=cut

# Wraps every raw output handle of this executor in an AI::MXNet::NDArray.
# NOTE(review): presumably used as the builder behind the public `outputs`
# accessor documented above — confirm in the attribute declaration.
method _get_outputs()
{
    my $output_handles = check_call(AI::MXNetCAPI::ExecutorOutputs($self->handle));
    my @wrapped;
    for my $output_handle (@{ $output_handles })
    {
        push @wrapped, AI::MXNet::NDArray->new(handle => $output_handle);
    }
    return \@wrapped;
}

=head2 forward

    Calculate the outputs specified by the bound symbol.

    Parameters
    ----------
    $is_train=0: bool, optional
        Whether this forward pass is for training. If true, a backward
        call is expected to follow; otherwise a following backward call
        is invalid.

    %kwargs
        Additional specification of input arguments, given as
        name => value pairs. Each name must be an argument of this
        executor (a key of arg_dict). Each value must be an
        AI::MXNet::NDArray, a PDL or a Perl array ref; non-NDArray values
        are converted with AI::MXNet::NDArray->array. The value's shape
        must equal the shape of the bound argument, and it is then
        assigned into that argument before the forward pass runs.

    Returns
    -------
    $self->outputs, the array ref of output NDArrays.

    Examples
    --------
        >>> # doing forward by specifying data
        >>> $texec->forward(1, data => $mydata);
        >>> # doing forward by not specifying things, but copy to the executor before hand
        >>> $mydata->copyto($texec->arg_dict->{'data'});
        >>> $texec->forward(1);
        >>> # doing forward by specifying data and get outputs
        >>> my $outputs = $texec->forward(1, data => $mydata);
        >>> print $outputs->[0]->aspdl;

=cut

method forward(Int $is_train=0, %kwargs)
{
    if(%kwargs)
    {
        my $arg_dict = $self->arg_dict;
        for my $name (keys %kwargs)
        {
            my $array = $kwargs{ $name };
            if(not find_type_constraint('AcceptableInput')->check($array))
            {
                confess('only accept keyword argument of NDArrays/PDLs/Perl Array refs');
            }
            if(not exists $arg_dict->{ $name })
            {
                confess("unknown argument $name");
            }
            # PDLs and plain array refs are converted to an NDArray first.
            if(not blessed($array) or not $array->isa('AI::MXNet::NDArray'))
            {
                $array = AI::MXNet::NDArray->array($array);
            }
            my $expected = $arg_dict->{ $name }->shape;
            my $got      = $array->shape;
            if(join(',', @{ $expected }) ne join(',', @{ $got }))
            {
                confess("Shape not match! Argument $name, need: @$expected, received: @$got");
            }
            # NOTE(review): relies on AI::MXNet::NDArray overloading `.=` to
            # assign $array into the already-bound argument array — confirm.
            $arg_dict->{ $name } .= $array;
        }
    }
    check_call(AI::MXNetCAPI::ExecutorForward(
            $self->handle,
            $is_train
        )
    );
    # _output_dirty is set by a training forward and cleared by backward,
    # so a true value here means two forwards ran without a backward.
    if($self->_output_dirty)
    {
        AI::MXNet::Logging->warning(
            "Calling forward the second time after forward(is_train=1) "
            ."without calling backward first. Is this intended?"
        );
    }
    $self->_output_dirty($is_train);
    return $self->outputs;
}

=head2 backward

    Do a backward pass to get the gradient of the arguments.

    Parameters
    ----------
    out_grads : NDArray or an array ref of NDArrays or hash ref of NDArrays, optional.
        The gradient on the outputs to be propagated back.
        This parameter is only needed when bind is called
        on outputs that are not a loss function.
        When omitted or undef, no output gradients are passed.
        A single NDArray is treated as a one-element array ref.
        A hash ref must contain an entry for every name returned by
        $self->symbol->list_outputs; the gradients are passed in that
        order, and any extra keys are ignored.

=cut

method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]|HashRef[AI::MXNet::NDArray]] $out_grads=)
{
    $out_grads //= [];
    if(blessed $out_grads)
    {
        $out_grads = [$out_grads];
    }
    elsif(ref $out_grads eq 'HASH')
    {
        my $output_names = $self->symbol->list_outputs();
        # Fail with a clear message instead of a later
        # "Can't call method handle on an undefined value".
        my @missing = grep { not exists $out_grads->{ $_ } } @{ $output_names };
        if(@missing)
        {
            confess("out_grads is missing gradients for output(s): @missing");
        }
        $out_grads = [ @{ $out_grads }{ @{ $output_names } } ];
    }
    check_call(
        AI::MXNetCAPI::ExecutorBackward(
            $self->handle,
            scalar(@{ $out_grads }),
            [map { $_->handle } @{ $out_grads }]
        )
    );
    # _output_dirty is only true after a training forward pass.
    if(not $self->_output_dirty)
    {
        AI::MXNet::Logging->warning(
            "Calling backward without calling forward(is_train=True) "
            ."first. Behavior is undefined."
        );
    }
    $self->_output_dirty(0);
}

=head2 set_monitor_callback

    Install a monitor callback on this executor.

    Parameters
    ----------
    callback : subref
        Takes a string and an NDArrayHandle.

=cut

method set_monitor_callback(CodeRef $callback)
{
    # Store the callback on the executor first and hand the stored value to
    # the C API. NOTE(review): presumably this keeps a Perl-side reference
    # alive for as long as the executor — confirm.
    $self->_monitor_callback($callback);
    my ($exec_handle, $installed_cb) = ($self->handle, $self->_monitor_callback);
    check_call(AI::MXNetCAPI::ExecutorSetMonitorCallback($exec_handle, $installed_cb));
}

=head2 arg_dict

    Get a hash ref representation of the argument arrays.

    Returns
    -------
    arg_dict : HashRef[AI::MXNet::NDArray]
        The map that maps a name of the arguments to the NDArrays.
=cut

method arg_dict()
{
    if(not defined $self->_arg_dict)
    {
        $self->_arg_dict(_get_dict(



( run in 0.753 second using v1.01-cache-2.11-cpan-39bf76dae61 )