AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Contrib/AutoGrad.pm  view on Meta::CPAN

    }

    my @ograd_handles;
    for my $arr (@$out_grads)
    {
        push @ograd_handles, (defined $arr ? $arr->handle : undef);
    }
    assert(
        (@ograd_handles == @output_handles),
        "outputs and out_grads must have the same length"
    );

    check_call(
        AI::MXNetCAPI::AutogradBackward(
            scalar(@output_handles),
            \@output_handles,
            \@ograd_handles,
            $retain_graph
        )
    );
}

=head2 compute_gradient

    Compute the gradients of outputs w.r.t variables.

    Parameters
    ----------
    outputs: array ref of NDArray

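    Returns
    -------
    The gradients are not returned directly: they are written into the
    gradient buffers previously registered with mark_variables (this is how
    grad_and_loss obtains them). The return value is whatever backward
    returns and should not be relied on.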
=cut


method compute_gradient(ArrayRef[AI::MXNet::NDArray] $outputs)
{
    # Kept as a named alias of backward: $outputs is forwarded unchanged and
    # backward's remaining arguments are left at their defaults. Callers such
    # as grad_and_loss read the gradients from the buffers they registered
    # with mark_variables, not from this return value.
    return __PACKAGE__->backward($outputs);
}

=head2 grad_and_loss

    Return function that computes both gradient of arguments and loss value.

    Parameters
    ----------
    func: a perl sub
        The forward (loss) function.
    argnum: an int or an array ref of int
        The index of argument to calculate gradient for.

    Returns
    -------
    grad_and_loss_func: a perl sub
        A function that would compute both the gradient of arguments and loss value.
=cut

method grad_and_loss(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    return sub {
        my @args = @_;

        # Inputs to differentiate: every argument by default, otherwise only
        # the positions named by $argnum (a single index or an array ref).
        my @variables = @args;
        if (defined $argnum)
        {
            my @positions = ref $argnum ? @$argnum : ($argnum);
            @variables = @args[@positions];
        }

        for my $var (@variables)
        {
            assert(
                (blessed($var) and $var->isa('AI::MXNet::NDArray')),
                "type of autograd input should NDArray"
            );
        }

        # One zero-filled buffer per variable; after mark_variables these are
        # the arrays handed back to the caller as the gradients.
        my @grads = map { $_->zeros_like } @variables;
        __PACKAGE__->mark_variables(\@variables, \@grads);

        # Run the forward function in training mode. The previous mode is
        # restored even when $func dies, so a failing call cannot leave the
        # process stuck in training mode; the error is then re-thrown.
        my $prev = __PACKAGE__->set_is_training(1);
        my $outputs;
        my $ok = eval { $outputs = $func->(@args); 1 };
        my $error = $@;    # save before anything else can clobber $@
        __PACKAGE__->set_is_training(0) unless $prev;
        die $error unless $ok;

        # $func may return a single NDArray or an array ref of them.
        __PACKAGE__->compute_gradient(ref $outputs eq 'ARRAY' ? $outputs : [$outputs]);
        return (\@grads, $outputs);
    };
}

=head2 grad

    Return function that computes gradient of arguments.

    Parameters
    ----------
    func: a perl sub
        The forward (loss) function.
    argnum: an int or an array ref of int
        The index of argument to calculate gradient for.

    Returns
    -------
    grad_func: a perl function
        A function that would compute the gradient of arguments.
=cut


method grad(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    # Built on grad_and_loss: call it and keep only the gradients, which are
    # the first element of the (gradients, loss) pair it returns.
    my $with_loss = __PACKAGE__->grad_and_loss($func, $argnum);
    return sub {
        my ($gradients) = $with_loss->(@_);
        return $gradients;
    };
}

=head2 train_section

    Run a perl sub with training mode switched on.

    Parameters
    ----------
    sub: a perl sub
        Called with no arguments while training mode is on.

    Returns
    -------
    Whatever sub returns (evaluated in scalar context).

    The training mode that was active before the call is restored after sub
    finishes, including when it dies; the exception is then re-thrown.
=cut

method train_section(CodeRef $sub)
{
    my $prev = __PACKAGE__->set_is_training(1);
    my $result;
    my $ok = eval { $result = $sub->(); 1 };
    # Capture the error before the restore below can clobber $@.
    my $error = $@;
    # Only switch back if training mode was off when we entered.
    __PACKAGE__->set_is_training(0) unless $prev;
    die $error unless $ok;
    return $result;
}

=head2 test_section

    Run a perl sub with training mode switched off.

    Parameters
    ----------
    sub: a perl sub
        Called with no arguments while training mode is off.

    Returns
    -------
    Whatever sub returns (evaluated in scalar context).

    The training mode that was active before the call is restored after sub
    finishes, including when it dies; the exception is then re-thrown.
=cut

method test_section(CodeRef $sub)
{
    my $prev = __PACKAGE__->set_is_training(0);
    my $result;
    my $ok = eval { $result = $sub->(); 1 };
    # Capture the error before the restore below can clobber $@.
    my $error = $@;
    # Only switch back if training mode was on when we entered.
    __PACKAGE__->set_is_training(1) if $prev;
    die $error unless $ok;
    return $result;
}

# A Perl module file must end by evaluating to a true value so that
# require/use of this module succeeds.
1;



( run in 0.486 second using v1.01-cache-2.11-cpan-39bf76dae61 )