AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Contrib/AutoGrad.pm view on Meta::CPAN
argnum: an int or an array ref of int
The index of argument to calculate gradient for.
Returns
-------
grad_and_loss_func: a perl sub
A function that would compute both the gradient of arguments and loss value.
=cut
method grad_and_loss(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    # Returns a closure that, when called with NDArray arguments, computes
    # the loss via $func and the gradients of the selected arguments.
    # The closure returns (\@grads, $outputs).
    return sub {
        my @args = @_;
        my @variables = @_;
        if(defined $argnum)
        {
            # Only differentiate with respect to the requested argument
            # positions; $argnum may be a single index or an array ref.
            my @argnum = ref $argnum ? @$argnum : ($argnum);
            @variables = map { $_[$_] } @argnum;
        }
        # Validate each input up front; a plain for loop is used because
        # map in void context is an anti-pattern (we only want the side
        # effect of the assertion, not a result list).
        for my $variable (@variables)
        {
            assert(
                (blessed($variable) and $variable->isa('AI::MXNet::NDArray')),
                "type of autograd input should be NDArray"
            );
        }
        # One zero-initialized gradient buffer per differentiated variable.
        my @grads = map { $_->zeros_like } @variables;
        __PACKAGE__->mark_variables(\@variables, \@grads);
        # Run the forward pass in training mode, restoring the previous
        # mode afterwards (only if it was not already training).
        my $prev = __PACKAGE__->set_is_training(1);
        my $outputs = $func->(@args);
        __PACKAGE__->set_is_training(0) unless $prev;
        # compute_gradient expects an array ref of outputs.
        __PACKAGE__->compute_gradient(ref $outputs eq 'ARRAY' ? $outputs : [$outputs]);
        return (\@grads, $outputs);
    };
}
=head2 grad
Return function that computes gradient of arguments.
Parameters
----------
func: a perl sub
The forward (loss) function.
argnum: an int or array ref of int
The index of argument to calculate gradient for.
Returns
-------
grad_func: a perl function
A function that would compute the gradient of arguments.
=cut
method grad(CodeRef $func, Maybe[Int|ArrayRef[Int]] $argnum=)
{
    # Build the combined gradient-and-loss closure once, up front.
    my $compute_both = __PACKAGE__->grad_and_loss($func, $argnum);
    # Expose only the gradients (first element of the returned pair);
    # the loss value computed by $func is discarded.
    return sub {
        my ($gradients) = $compute_both->(@_);
        return $gradients;
    };
}
=head2 train_section

    Executes a sub within an autograd training scope.

    Parameters
    ----------
    sub: a perl sub
=cut

method train_section(CodeRef $sub)
{
    # Switch into training mode, remembering the previous state so it
    # can be restored exactly (only disable if it was not already on).
    my $prev = __PACKAGE__->set_is_training(1);
    # Run the callback inside eval so the mode is restored even when
    # the callback dies; the error is rethrown afterwards.
    my $ok  = eval { $sub->(); 1 };
    my $err = $@;
    __PACKAGE__->set_is_training(0) unless $prev;
    die $err unless $ok;
}
=head2 test_section

    Executes a sub within an autograd inference (non-training) scope.

    Parameters
    ----------
    sub: a perl sub
=cut

method test_section(CodeRef $sub)
{
    # Switch into inference mode, remembering the previous state so it
    # can be restored exactly (only re-enable if it was on before).
    my $prev = __PACKAGE__->set_is_training(0);
    # Run the callback inside eval so the mode is restored even when
    # the callback dies; the error is rethrown afterwards.
    my $ok  = eval { $sub->(); 1 };
    my $err = $@;
    __PACKAGE__->set_is_training(1) if $prev;
    die $err unless $ok;
}
1;
( run in 0.648 second using v1.01-cache-2.11-cpan-39bf76dae61 )