AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Contrib/AutoGrad.pm  view on Meta::CPAN

package AI::MXNet::Contrib::AutoGrad;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use Scalar::Util qw(blessed);

=head1 NAME

    AI::MXNet::Contrib::AutoGrad - Autograd for NDArray.

=cut

=head2 set_is_training

    Switch autograd training mode on or off.

    While training mode is on, a graph is recorded so that gradients can be
    computed later, and operators run with ctx.is_train=True (for example,
    Dropout drops inputs randomly when training and passes them through
    unchanged otherwise).

    Parameters
    ----------
    is_train: bool
        Desired state; any true value enables training mode.

    Returns
    -------
    1 if training mode was on before this call, 0 otherwise.

=cut


method set_is_training(Bool $is_train)
{
    # The C API takes a strict 0/1 flag rather than an arbitrary Perl boolean.
    my $requested = $is_train ? 1 : 0;
    my $previous  = scalar(
        check_call(AI::MXNetCAPI::AutogradSetIsTraining($requested))
    );
    # Normalise whatever the C layer handed back to a plain 0/1.
    return $previous ? 1 : 0;
}

=head2 mark_variables

    Mark AI::MXNet::NDArrays as variables to compute gradients for autograd.

    Parameters
    ----------
    variables: array ref of AI::MXNet::NDArrays
    gradients: array ref of AI::MXNet::NDArrays
        Must have exactly one entry per variable.
    grad_reqs: a GradReq string or an array ref of GradReq strings
        A single string is applied to every variable; an array ref must
        have exactly one entry per variable. Defaults to 'write'.

=cut

method mark_variables(
    ArrayRef[AI::MXNet::NDArray]  $variables,
    ArrayRef[AI::MXNet::NDArray]  $gradients,
    GradReq|ArrayRef[GradReq]     $grad_reqs='write'
)
{
    my $num_vars = scalar(@{ $variables });
    # The C call below receives a single count for all three arrays, so a
    # shorter gradients/grad_reqs array would be read past its end there.
    assert(
        (@{ $gradients } == $num_vars),
        "variables and gradients must have the same length"
    );
    my @variable_handles = map { $_->handle } @{ $variables };
    my @gradient_handles = map { $_->handle } @{ $gradients };
    my @grad_reqs;
    if(not ref $grad_reqs)
    {
        # One request string shared by every variable.
        @grad_reqs = (GRAD_REQ_MAP->{ $grad_reqs }) x $num_vars;
    }
    else
    {
        @grad_reqs = map { GRAD_REQ_MAP->{ $_ } } @{ $grad_reqs };
        assert(
            (@grad_reqs == $num_vars),
            "variables and grad_reqs must have the same length"
        );
    }
    check_call(
        AI::MXNetCAPI::AutogradMarkVariables(
            $num_vars,
            \@variable_handles,
            \@grad_reqs,
            \@gradient_handles
        )
    );
}

=head2 backward

     Compute the gradients of outputs w.r.t. the marked variables.

     Parameters
     ----------
     outputs: array ref of NDArray
     out_grads: array ref of NDArray (individual entries may be undef), or undef
         When undef, an empty handle list is passed to the C API. When given,
         it must have the same length as outputs.
     retain_graph: bool, defaults to false
         Passed through unchanged to the C API.

=cut


method backward(
    ArrayRef[AI::MXNet::NDArray] $outputs,
    Maybe[ArrayRef[AI::MXNet::NDArray|Undef]] $out_grads=,
    Bool $retain_graph=0
)
{
    my @output_handles = map { $_->handle } @{ $outputs };
    my $num_outputs    = scalar(@output_handles);

    # No head gradients supplied: hand the C API an empty list and stop.
    unless(defined $out_grads)
    {
        check_call(
            AI::MXNetCAPI::AutogradBackward(
                $num_outputs, \@output_handles, [], $retain_graph
            )
        );
        return;
    }

    # Undefined entries are kept as undef placeholders, position by position.
    my @ograd_handles =
        map { defined($_) ? $_->handle : undef } @{ $out_grads };
    assert(
        (@ograd_handles == $num_outputs),
        "outputs and out_grads must have the same length"
    );

    check_call(
        AI::MXNetCAPI::AutogradBackward(
            $num_outputs, \@output_handles, \@ograd_handles, $retain_graph
        )
    );
}

=head2 compute_gradient

    Compute the gradients of outputs w.r.t. the marked variables.

    This is a thin wrapper around backward($outputs): no out_grads are
    supplied and retain_graph keeps its default. The gradients are NOT
    returned by this method; read them from the gradient arrays passed to
    mark_variables (as grad_and_loss does).

    Parameters
    ----------
    outputs: array ref of NDArray

    Returns
    -------
    Nothing (empty list / undef).

=cut


method compute_gradient(ArrayRef[AI::MXNet::NDArray] $outputs)
{
    __PACKAGE__->backward($outputs);
    # Explicit, so callers are not tempted to use backward()'s incidental
    # return value as "the gradients".
    return;
}

=head2 grad_and_loss

    Return a function that computes both the gradient of arguments and the
    loss value.

    Parameters
    ----------
    func: a perl sub
        The forward (loss) function. It is called in scalar context with the
        wrapper's arguments and may return a single NDArray or an array ref
        of NDArrays.
    argnum: an int or an array ref of ints, optional
        Indices of the arguments to compute gradients for. When omitted,
        gradients are computed for every argument.

    Returns
    -------
    grad_and_loss_func: a perl sub
        Calling it returns a two-element list: an array ref of gradient
        NDArrays (one per selected argument, in selection order) and the
        value returned by func. Every selected argument must be an
        AI::MXNet::NDArray. The previous training state is restored even if
        func dies; the error is then rethrown.

=cut

method grad_and_loss(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    return sub {
        my @args      = @_;
        my @variables = @args;
        if(defined $argnum)
        {
            my @indices = ref $argnum ? @{ $argnum } : ($argnum);
            @variables = map { $args[$_] } @indices;
        }
        for my $var (@variables)
        {
            assert(
                (blessed($var) and $var->isa('AI::MXNet::NDArray')),
                "type of autograd input should NDArray"
            );
        }
        my @grads = map { $_->zeros_like } @variables;
        __PACKAGE__->mark_variables(\@variables, \@grads);

        my $prev = __PACKAGE__->set_is_training(1);
        my $outputs;
        my $ok    = eval { $outputs = $func->(@args); 1 };
        my $error = $@;
        # Restore the caller's training state before propagating any error,
        # so a failing forward pass does not leave training mode switched on.
        __PACKAGE__->set_is_training(0) unless $prev;
        die $error unless $ok;

        __PACKAGE__->compute_gradient(
            ref $outputs eq 'ARRAY' ? $outputs : [$outputs]
        );
        return (\@grads, $outputs);
    };
}

=head2 grad

    Return a function that computes the gradient of arguments.

    Parameters
    ----------
    func: a perl sub
        The forward (loss) function.
    argnum: an int or an array ref of ints, optional
        Indices of the arguments to compute gradients for.

    Returns
    -------
    grad_func: a perl sub
        Behaves like the function returned by grad_and_loss, but returns
        only its first element (the array ref of gradients) and discards
        the loss value.

=cut


method grad(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    my $grad_and_loss_func = __PACKAGE__->grad_and_loss($func, $argnum);
    return sub {
        my @grads_and_loss = $grad_and_loss_func->(@_);
        return @grads_and_loss[0];
    };
}

=head2 train_section

    Run a code block with autograd training mode switched on.

    Training mode is restored to its previous state when the block
    finishes, including when the block dies; the error is then rethrown.

    Parameters
    ----------
    sub: a perl sub, called with no arguments in void context

=cut

method train_section(CodeRef $sub)
{
    my $prev  = __PACKAGE__->set_is_training(1);
    my $ok    = eval { $sub->(); 1 };
    my $error = $@;
    # Restore before rethrowing, so an exception cannot leave the process
    # stuck in training mode.
    __PACKAGE__->set_is_training(0) unless $prev;
    die $error unless $ok;
    return;
}

=head2 test_section

    Run a code block with autograd training mode switched off.

    Training mode is restored to its previous state when the block
    finishes, including when the block dies; the error is then rethrown.

    Parameters
    ----------
    sub: a perl sub, called with no arguments in void context

=cut

method test_section(CodeRef $sub)
{
    my $prev  = __PACKAGE__->set_is_training(0);
    my $ok    = eval { $sub->(); 1 };
    my $error = $@;
    # Re-enable training (if it was on) before rethrowing, so an exception
    # cannot silently leave training mode disabled.
    __PACKAGE__->set_is_training(1) if $prev;
    die $error unless $ok;
    return;
}

1;

 view all matches for this distribution
 view release on metacpan -  search on metacpan

( run in 0.499 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )