AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

package AI::MXNet::Symbol;

=head1 NAME

    AI::MXNet::Symbol - Symbolic interface of MXNet.
=cut

use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Symbol::Base;
use AI::MXNet::Types;
use Mouse;
use AI::MXNet::Function::Parameters;
use overload
    '""'  => \&stringify,
    '+'   => \&add,
    '-'   => \&subtract,
    '*'   => \&multiply,
    '/'   => \&divide,
    '/='  => \&idivide,
    '**'  => \&power,
    '%'   => \&mod,
    '=='  => \&equal,
    '!='  => \&not_equal,
    '>'   => \&greater,
    '>='  => \&greater_equal,
    '<'   => \&lesser,
    '<='  => \&lesser_equal,
    '&{}' => sub { my $self = shift; sub { $self->call(@_) } },
    '@{}' => sub { my $self = shift; [map { $self->slice($_) } @{ $self->list_outputs }] };

extends 'AI::MXNet::Symbol::Base';

# Native symbol handle wrapped by this object; released in DEMOLISH via
# AI::NNVMCAPI::SymbolFree.
has handle => (
    is       => 'rw',
    isa      => 'SymbolHandle',
    required => 1,
);

# Release the native symbol when the Perl object is destroyed.
#
# Storable builds the blessed object first and only then calls
# STORABLE_thaw. If that thaw dies before `handle` is set, the object is
# destroyed with no handle at all. There is nothing to free in that case,
# and passing undef to SymbolFree would fail inside the destructor.
sub DEMOLISH
{
    my ($self) = @_;
    my $handle = $self->handle;
    return unless defined $handle;
    check_call(AI::NNVMCAPI::SymbolFree($handle));
}

# Storable hook: serialize the symbol as its JSON representation.
# The $cloning flag is accepted but not used; freezing and cloning both
# go through the same JSON form.
method STORABLE_freeze($cloning)
{
    my $serialized = $self->tojson;
    return $serialized;
}

# Storable hook: rebuild the native symbol from the JSON produced by
# STORABLE_freeze and attach the new handle to the object Storable created.
# The $cloning flag is accepted but not used.
method STORABLE_thaw($cloning, $json)
{
    # check_call is evaluated in scalar context, exactly as the original
    # `my $handle = check_call(...)` assignment did.
    return $self->handle(
        scalar(check_call(AI::MXNetCAPI::SymbolCreateFromJSON($json)))
    );
}

# Overloaded '""': render as "<ClassName name>".
# A symbol whose name is false (e.g. a grouped symbol) is shown as 'Grouped'.
# The extra overload arguments ($other, $reverse) are ignored.
method stringify($other=, $reverse=)
{
    my $label = $self->name || 'Grouped';
    return sprintf("<%s %s>", ref($self), $label);
}

# Overloaded '+': delegate to _ufunc_helper with the symbol and scalar
# op names. Addition is commutative, so no reversed-scalar op is supplied
# and $reverse is not forwarded.
method add(AI::MXNet::Symbol|Num $other, $reverse=)
{
    my @op_names = qw/_Plus _PlusScalar/;
    return _ufunc_helper($self, $other, @op_names);
}

# Overloaded '-': delegate to _ufunc_helper.
# Subtraction is not commutative, so a reversed-scalar op name is supplied
# and the overload's $reverse flag is forwarded (presumably so that
# `5 - $sym` selects _RMinusScalar).
method subtract(AI::MXNet::Symbol|Num $other, $reverse=)
{
    my @op_names = qw/_Minus _MinusScalar _RMinusScalar/;
    return _ufunc_helper($self, $other, @op_names, $reverse);
}

# Overloaded '*': delegate to _ufunc_helper with the symbol and scalar
# op names. Multiplication is commutative, so $reverse is not forwarded.
method multiply(AI::MXNet::Symbol|Num $other, $reverse=)
{
    my @op_names = qw/_Mul _MulScalar/;
    return _ufunc_helper($self, $other, @op_names);
}

# Overloaded '/': delegate to _ufunc_helper.
# Division is not commutative, so the reversed-scalar op name is supplied
# and the overload's $reverse flag is forwarded.
method divide(AI::MXNet::Symbol|Num $other, $reverse=)
{
    my @op_names = qw/_Div _DivScalar _RDivScalar/;
    return _ufunc_helper($self, $other, @op_names, $reverse);
}

# Overloaded '**': delegate to _ufunc_helper.
# Exponentiation is not commutative, so the reversed-scalar op name is
# supplied and the overload's $reverse flag is forwarded.
method power(AI::MXNet::Symbol|Num $other, $reverse=)
{
    my @op_names = qw/_Power _PowerScalar _RPowerScalar/;
    return _ufunc_helper($self, $other, @op_names, $reverse);
}

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

            scalar(@{ $indptr }) - 1,
            $keys,
            $indptr,
            $sdata,
        )
    );
    if($complete)
    {
        return $arg_shapes, $out_shapes, $aux_shapes;
    }
    else
    {
        return (undef, undef, undef);
    }
}

=head2 debug_str

    The debug string.

    Returns
    -------
    debug_str : string
        Debug string of the symbol.
=cut

# Return the textual dump produced by the C API's SymbolPrint.
method debug_str()
{
    my $printed = check_call(AI::MXNetCAPI::SymbolPrint($self->handle));
    return $printed;
}

=head2 save

        Save the symbol into a file.

        You can also use Storable to do the job if you only work with Perl.
        The advantage of load/save is that the file is language-agnostic:
        a file saved with save can be loaded by other language bindings of MXNet.
        You also get the benefit of being able to load/save directly from cloud storage (S3, HDFS).

        Parameters
        ----------
        fname : str
            The name of the file
            - s3://my-bucket/path/my-s3-symbol
            - hdfs://my-bucket/path/my-hdfs-symbol
            - /path-to/my-local-symbol

        See Also
        --------
        load : Used to load symbol from file.
=cut

# Write this symbol to $fname through the C API (SymbolSaveToFile).
# Returns whatever check_call returns, in the caller's context.
method save(Str $fname)
{
    my $sym_handle = $self->handle;
    return check_call(AI::MXNetCAPI::SymbolSaveToFile($sym_handle, $fname));
}

=head2 tojson

        Save the symbol into a JSON string.

        See Also
        --------
        load_json : Used to load symbol from JSON string.
=cut

# Serialize the symbol to a JSON string via SymbolSaveToJSON.
# The result is always a single scalar.
method tojson()
{
    my $json = check_call(AI::MXNetCAPI::SymbolSaveToJSON($self->handle));
    return $json;
}

# Resolve NDArray inputs for binding into two parallel array refs ordered
# like $arg_names:
#   - native handles (each NDArray's ->handle, or undef)
#   - the NDArray objects themselves (or undef)
#
# $args may be either:
#   - an array ref, which must have exactly one entry per name in $arg_names
#     and is used positionally; the caller's array ref is returned as-is as
#     the second value; or
#   - a hash ref keyed by argument name. Keys not listed in $arg_names are
#     ignored.
#
# When a named argument is absent:
#   - if $allow_missing is true, undef is placed in both outputs;
#   - otherwise we confess, naming the first missing argument in
#     $arg_names order. That order is deterministic, unlike hash order.
method _get_ndarray_inputs(
    Str                                                      $arg_key,
    HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] $args,
    ArrayRef[Str]                                            $arg_names,
    Bool                                                     $allow_missing=0
)
{
    my ($arg_handles, $arg_arrays) = ([], []);
    if(ref $args eq 'ARRAY')
    {
        confess("Length of $arg_key do not match number of arguments")
            unless @$args == @$arg_names;
        @{ $arg_handles } = map { $_->handle } @{ $args };
        $arg_arrays = $args;
    }
    else
    {
        for my $name (@$arg_names)
        {
            # Plain rvalue lookup; this does not autovivify keys in the
            # caller's hash.
            my $array = $args->{ $name };
            if(not defined $array and not $allow_missing)
            {
                confess("key $name is missing in $arg_key");
            }
            push @$arg_handles, defined($array) ? $array->handle : undef;
            push @$arg_arrays, $array;
        }
    }
    return ($arg_handles, $arg_arrays);
}

=head2 simple_bind

    Bind current symbol to get an executor, allocate all the ndarrays needed.
    Allows specifying data types.

    This function asks the user to pass in the NDArrays for the positions
    they want to bind to, and automatically allocates NDArrays for any
    arguments and auxiliary states that the user did not specify explicitly.

    Parameters
    ----------
    :$ctx : AI::MXNet::Context
        The device context the generated executor to run on.

    :$grad_req: string
        {'write', 'add', 'null'}, or list of str or dict of str to str, optional
        Specifies how we should update the gradient to the args_grad.
            - 'write' means the gradient is written to the specified args_grad NDArray every time.
            - 'add' means the gradient is added to the specified NDArray every time.
            - 'null' means no action is taken; the gradient may not be calculated.

    :$type_dict  : hash ref of str->Dtype
        Input type map, name->dtype

    :$group2ctx : hash ref of string to AI::MXNet::Context
        The mapping of the ctx_group attribute to the context assignment.

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

)
{
    my $handle = check_call(AI::MXNetCAPI::SymbolCreateVariable($name));
    my $ret = __PACKAGE__->new(handle => $handle);
    $attr = AI::MXNet::Symbol::AttrScope->current->get($attr);
    $attr->{__shape__}   = "(".join(',', @{ $shape }).")" if $shape;
    $attr->{__lr_mult__} =  $lr_mult if defined $lr_mult;
    $attr->{__wd_mult__} =  $wd_mult if defined $wd_mult;
    $attr->{__dtype__}   = DTYPE_STR_TO_MX->{ $dtype } if $dtype;
    $attr->{__init__}    = "$init" if defined $init;
    $attr->{__layout__}  = $__layout__ if defined $__layout__;
    while(my ($k, $v) = each %{ $kwargs })
    {
        if($k =~ /^__/ and $k =~ /__$/)
        {
            $attr->{$k} = "$v";
        }
        else
        {
            confess("Attribute name=$k is not supported.".
                    ' Additional attributes must start and end with double underscores,'.
                    ' e.g, __yourattr__'
            );
        }
    }
    $ret->_set_attr(%{ $attr });
    return $ret;
}

=head2 var

    A synonym to Variable.
=cut

# Install `var` as a second name for the same code ref as Variable.
*var = \&AI::MXNet::Symbol::Variable;

=head2 Group

    Create a symbol that groups symbols together.

    Parameters
    ----------
    symbols : array ref
        List of symbols to be grouped.

    Returns
    -------
    sym : Symbol
        The created group symbol.
=cut

# Collect the native handles of every symbol in $symbols and ask the C API
# (SymbolCreateGroup) for a single grouped symbol wrapping them.
method Group(ArrayRef[AI::MXNet::Symbol] $symbols)
{
    my @member_handles;
    push @member_handles, $_->handle for @{ $symbols };
    my $group_handle = check_call(
        AI::MXNetCAPI::SymbolCreateGroup(scalar(@member_handles), \@member_handles)
    );
    return __PACKAGE__->new(handle => $group_handle);
}

=head2 load

    Load symbol from a JSON file.

    You can also use Storable to do the job if you only work with Perl.
    The advantage of load/save is that the file is language-agnostic:
    a file saved with save can be loaded by other language bindings of MXNet.
    You also get the benefit of being able to load/save directly from cloud storage (S3, HDFS).

    Parameters
    ----------
    fname : str
        The name of the file, examples:

        - `s3://my-bucket/path/my-s3-symbol`
        - `hdfs://my-bucket/path/my-hdfs-symbol`
        - `/path-to/my-local-symbol`

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    See Also
    --------
    AI::MXNet::Symbol->save : Used to save symbol into file.
=cut

# Construct a new symbol from a file previously written by save.
method load(Str $fname)
{
    # scalar() keeps check_call in scalar context, matching a plain
    # `my $h = check_call(...)` assignment.
    return __PACKAGE__->new(
        handle => scalar(check_call(AI::MXNetCAPI::SymbolCreateFromFile($fname)))
    );
}

=head2 load_json

    Load symbol from a JSON string.

    Parameters
    ----------
    json_str : str
        A json string.

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    See Also
    --------
    AI::MXNet::Symbol->tojson : Used to save symbol into json string.
=cut

# Construct a new symbol from a JSON string such as the output of tojson.
method load_json(Str $json)
{
    # scalar() keeps check_call in scalar context, matching a plain
    # `my $h = check_call(...)` assignment.
    return __PACKAGE__->new(
        handle => scalar(check_call(AI::MXNetCAPI::SymbolCreateFromJSON($json)))
    );
}

# Build a zeros symbol through the generated _zeros operator.
# __layout__ is forwarded only when it is a true value.
method zeros(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=)
{
    my %op_params = (shape => $shape, dtype => $dtype, name => $name);
    $op_params{__layout__} = $__layout__ if $__layout__;
    return __PACKAGE__->_zeros(\%op_params);
}

# Build a ones symbol through the generated _ones operator.
# __layout__ is forwarded only when it is a true value.
method ones(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=)
{
    my %op_params = (shape => $shape, dtype => $dtype, name => $name);
    $op_params{__layout__} = $__layout__ if $__layout__;
    return __PACKAGE__->_ones(\%op_params);
}

=head2 arange

    Similar to numpy.arange, but returns an MXNet Symbol.
        See also https://docs.scipy.org/doc/numpy/reference/generated/numpy.arange.html.

    Parameters
    ----------
    start : number
        Start of interval. The interval includes this value. The default start value is 0.
    stop : number, optional
        End of interval. The interval does not include this value.
    step : number, optional
        Spacing between values
    repeat : int, optional
        The number of times each element is repeated.
        E.g. with repeat=3, element a is repeated three times --> a, a, a.
    dtype : type, optional
        The value type of the output, defaults to 'float32'.

    Returns
    -------
    out : Symbol
        The created Symbol
=cut

# Build an arange symbol through the generated _arange operator.
# `stop` is only passed when the caller supplied a defined value; every
# other parameter is always forwarded.
method arange(Index :$start=0, Index :$stop=, Num :$step=1.0, Index :$repeat=1, Maybe[Str] :$name=, Dtype :$dtype='float32')
{
    my %op_params = (
        start  => $start,
        step   => $step,
        repeat => $repeat,
        name   => $name,
        dtype  => $dtype,
    );
    $op_params{stop} = $stop if defined $stop;
    return __PACKAGE__->_arange(\%op_params);
}


sub _parse_arguments
{
    my $type = shift;
    my @args = @_;
    my $type_c = find_type_constraint($type);
    my $str_c  = find_type_constraint("Str");
    my @positional_arguments;
    my %kwargs;
    my @kwargs_order;
    my $only_dtypes_and_undefs = (@args == grep { not defined($_) or $type_c->check($_) } @args);
    my $only_dtypes_and_strs   = (@args == grep { $type_c->check($_) or $str_c->check($_) } @args);
    if(@args % 2 and $only_dtypes_and_undefs)
    {



( run in 1.363 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )