AI-MXNet
view release on metacpan - search on metacpan
view release on metacpan or search on metacpan
lib/AI/MXNet/Base.pm view on Meta::CPAN
package AI::MXNet::Base;
use strict;
use warnings;
use PDL;
use PDL::Types qw();
use AI::MXNetCAPI 1.0102;
use AI::NNVMCAPI 1.01;
use AI::MXNet::Types;
use Time::HiRes;
use Carp;
use Exporter;
use base qw(Exporter);
use List::Util qw(shuffle);
# Names exported by default to every package that imports this module.
@AI::MXNet::Base::EXPORT = (
    # helper subs defined in this file
    qw(product enumerate assert zip check_call build_param_doc),
    # NOTE(review): pdl/cat/dog/svd are not defined in this file; presumably
    # they are re-exports of the functions pulled in by `use PDL` -- confirm.
    qw(pdl cat dog svd),
    qw(bisect_left pdl_shuffle),
    # lookup-table constants declared below
    qw(DTYPE_STR_TO_MX DTYPE_MX_TO_STR DTYPE_MX_TO_PDL),
    qw(DTYPE_PDL_TO_MX DTYPE_MX_TO_PERL GRAD_REQ_MAP),
);

# Names exported only when explicitly requested by the importer.
@AI::MXNet::Base::EXPORT_OK = ('pzeros', 'pceil');
# Lookup tables, each exposed as a constant returning a hash reference.
# NOTE(review): the integer dtype codes are presumably MXNet's type flags and
# the 0/3/6/7 values presumably PDL type numbers -- confirm against both APIs.
use constant {
    # dtype name -> integer dtype code
    DTYPE_STR_TO_MX => {
        'float32' => 0,
        'float64' => 1,
        'float16' => 2,
        'uint8'   => 3,
        'int32'   => 4,
    },

    # integer dtype code -> dtype name (inverse of DTYPE_STR_TO_MX)
    DTYPE_MX_TO_STR => {
        0 => 'float32',
        1 => 'float64',
        2 => 'float16',
        3 => 'uint8',
        4 => 'int32',
    },

    # dtype (by integer code or by name) -> PDL-side type number.
    # float16 shares the value used for float32.
    DTYPE_MX_TO_PDL => {
        0         => 6,
        1         => 7,
        2         => 6,
        3         => 0,
        4         => 3,
        'float32' => 6,
        'float64' => 7,
        'float16' => 6,
        'uint8'   => 0,
        'int32'   => 3,
    },

    # PDL-side type number -> integer dtype code
    DTYPE_PDL_TO_MX => {
        6 => 0,
        7 => 1,
        0 => 3,
        3 => 4,
    },

    # dtype (by integer code or by name) -> single-letter code.
    # NOTE(review): the letters look like pack()/unpack() templates -- confirm.
    DTYPE_MX_TO_PERL => {
        0         => 'f',
        1         => 'd',
        2         => 'S',
        3         => 'C',
        4         => 'l',
        'float32' => 'f',
        'float64' => 'd',
        'float16' => 'S',
        'uint8'   => 'C',
        'int32'   => 'l',
    },

    # gradient request name -> integer code
    GRAD_REQ_MAP => {
        'null'  => 0,
        'write' => 1,
        'add'   => 3,
    },
};
=head1 NAME
AI::MXNet::Base - Helper functions
=head1 DEFINITION
Helper functions
=head2 zip
Perl version of for x,y,z in zip (arr_x, arr_y, arr_z)
Parameters
----------
$sub_ref, called with @_ filled with $arr_x->[$i], $arr_y->[$i], $arr_z->[$i]
for each loop iteration.
@array_refs
=cut
# Walks several array refs in lock-step and invokes the callback once per
# index with the i-th element of every array as its arguments.
#   $_[0]     - code ref to call
#   @_[1..]   - array refs; the length of the FIRST one sets the number of
#               iterations (shorter later arrays contribute undef).
sub zip
{
    my $callback = shift;
    my @lists    = @_;
    my $count    = scalar @{ $lists[0] };
    my $idx      = 0;
    while ($idx < $count)
    {
        $callback->(map { $_->[$idx] } @lists);
        $idx++;
    }
}
=head2 enumerate
Same as zip, but the argument list in the anonymous sub is prepended
by the iteration count.
=cut
# Like zip, but the callback additionally receives the zero-based iteration
# index as its first argument. The index list is sized from the first array.
sub enumerate
{
    my $callback = shift;
    my @lists    = @_;
    my $last_idx = scalar(@{ $lists[0] }) - 1;
    return zip($callback, [0 .. $last_idx], @lists);
}
=head2 product
Calculates the product of the input arguments.
=cut
# Multiplies all arguments together and returns the result.
# With no arguments the result is 1 (the empty product).
sub product
{
    my $p = 1;
    # Plain loop instead of map in void context; the expression is kept as
    # `$p * $_` (not `*=`) so operands with overloaded `*` behave as before.
    $p = $p * $_ for @_;
    return $p;
}
=head2 bisect_left
https://hg.python.org/cpython/file/2.7/Lib/bisect.py
=cut
# Binary search for the leftmost insertion point of $x.
#   $array - array ref, expected to be sorted ascending (compared with
#            numeric `<`)
#   $x     - value to locate
#   $lo    - optional lower bound of the search window (default 0)
#   $hi    - optional upper bound, exclusive (default: number of elements)
# Returns the smallest index i in [$lo, $hi] such that every element before i
# in the window is < $x. Dies via Carp::confess if $lo is negative.
sub bisect_left
{
    # The array ref is deliberately not named $a: a lexical $a would shadow
    # the special variable used by sort comparators.
    my ($array, $x, $lo, $hi) = @_;
    $lo //= 0;
    $hi //= scalar @{ $array };
    if ($lo < 0)
    {
        Carp::confess('lo must be non-negative');
    }
    while ($lo < $hi)
    {
        my $mid = int(($lo + $hi) / 2);
        if ($array->[$mid] < $x)
        {
            # everything up to and including $mid is too small
            $lo = $mid + 1;
        }
        else
        {
            $hi = $mid;
        }
    }
    return $lo;
}
=head2 pdl_shuffle
Shuffle the pdl by the last dimension
Parameters
-----------
PDL $pdl
$preshuffle Maybe[ArrayRef[Index]], if defined the array elements are used
as shuffled last dimension's indexes
=cut
# Returns a copy of $pdl whose slices along the last dimension are reordered.
#   $pdl        - source piddle (left unmodified; the result starts as a copy)
#   $preshuffle - optional array ref; when true its elements are used as the
#                 source index for each destination position, otherwise a
#                 random permutation of the last dimension's indexes is drawn.
sub pdl_shuffle
{
    my ($pdl, $preshuffle) = @_;
    my $shuffled = $pdl->copy;
    my $last_idx = $pdl->dim(-1) - 1;

    my @order;
    if ($preshuffle)
    {
        @order = @{ $preshuffle };
    }
    else
    {
        @order = shuffle(0 .. $last_idx);
    }

    # One 'X' per leading dimension, so each slice spans everything except
    # the last dimension, which is pinned to a single index.
    my @leading = ('X') x ($pdl->ndims - 1);
    for my $dst (0 .. $last_idx)
    {
        $shuffled->slice(@leading, $dst) .= $pdl->slice(@leading, $order[$dst]);
    }
    return $shuffled;
}
=head2 assert
Parameters
-----------
Bool $input
Str $error_str
Calls Carp::confess with $error_str//"AssertionError" if the $input is false
=cut
# Dies with a full backtrace when $input is false.
#   $input     - value tested for truth
#   $error_str - message to die with; 'AssertionError' when undefined
sub assert
{
    my ($input, $error_str) = @_;
    my $message = defined($error_str) ? $error_str : 'AssertionError';
    # Skip one extra stack frame in the report produced by Carp.
    local $Carp::CarpLevel = 1;
    $input or Carp::confess($message);
}
=head2 check_call
Checks the return value of a C API call.
This function will raise an exception when an error occurs.
Every API call is wrapped with this function.
Returns the C API call return values stripped of first return value,
checks for return context and returns first element in
the values list when called in scalar context.
=cut
# Validates the status returned by a C API call and passes the payload on.
# The first argument is the status: when it is true, dies with the text from
# AI::MXNetCAPI::GetLastError(). The remaining arguments are returned as a
# list in list context; otherwise only the first of them is returned.
sub check_call
{
    my $status = shift;
    if ($status)
    {
        Carp::confess(AI::MXNetCAPI::GetLastError());
    }
    return @_ if wantarray;
    return $_[0];
}
=head2 build_param_doc
Builds argument docs in python style.
arg_names : array ref of str
Argument names.
arg_types : array ref of str
Argument type information.
arg_descs : array ref of str
Argument description information.
remove_dup : boolean, optional
Whether to remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
=cut
# Renders a "Parameters" documentation section from parallel lists.
#   $arg_names  - array ref of argument names
#   $arg_types  - array ref of type descriptions
#   $arg_descs  - array ref of free-text descriptions
#   $remove_dup - when true (the default if undefined), only the first entry
#                 for each argument name is kept
# Returns the section as a single string.
sub build_param_doc
{
    my ($arg_names, $arg_types, $arg_descs, $remove_dup) = @_;
    $remove_dup = 1 unless defined $remove_dup;

    my (%seen, @entries);
    my $collect = sub {
        my ($name, $type, $text) = @_;
        if ($remove_dup and exists $seen{$name})
        {
            return;
        }
        $seen{$name} = 1;
        my $entry = sprintf("%s : %s", $name, $type);
        if (length($text))
        {
            # the description goes on its own line below the "name : type" line
            $entry .= "\n " . $text;
        }
        push @entries, $entry;
    };
    zip($collect, $arg_names, $arg_types, $arg_descs);

    my $body = join("\n", @entries);
    return sprintf("Parameters\n----------\n%s\n", $body);
}
=head2 _notify_shutdown
Notify MXNet about shutdown.
=cut
# Calls the C API's NotifyShutdown and routes its return values through
# check_call, which dies if the call reports a failure status.
sub _notify_shutdown
{
    my @result = AI::MXNetCAPI::NotifyShutdown();
    return check_call(@result);
}
# Runs when the interpreter exits: notifies MXNet about the shutdown and then
# pauses for 10 ms before the exit continues.
# NOTE(review): the pause presumably gives the MXNet side time to finish its
# own teardown -- confirm the reason before changing or removing it.
END {
_notify_shutdown();
Time::HiRes::sleep(0.01);
}
# Make this package's zeros/ceil subs reachable as pzeros/pceil; both alias
# names are offered through @EXPORT_OK.
# NOTE(review): zeros and ceil are not defined in this file; presumably they
# are the functions imported by `use PDL` -- confirm.
*pzeros = \&zeros;
*pceil = \&ceil;
## making sure that we can stringify arbitrarily large piddles
$PDL::toolongtoprint = 1000_000_000;
# A module file has to end with a true value.
1;
view all matches for this distributionview release on metacpan - search on metacpan
( run in 0.523 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )