AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

=head2 Variable

    Create a symbolic variable with specified name.

    Parameters
    ----------
    name : str
        Name of the variable.
    attr : hash ref of string -> string
        Additional attributes to set on the variable.
    shape : array ref of positive integers
        Optionally, one can specify the shape of a variable. This will be used during
        shape inference. If user specified a different shape for this variable using
        keyword argument when calling shape inference, this shape information will be ignored.
    lr_mult : float
        Specify learning rate multiplier for this variable.
    wd_mult : float
        Specify weight decay multiplier for this variable.
    dtype : Dtype
        Similar to shape, we can specify dtype for this variable.
    init : initializer (mx->init->*)
        Specify initializer for this variable to override the default initializer
    kwargs : hash ref
        other additional attribute variables
    Returns
    -------
    variable : Symbol
        The created variable symbol.
=cut

method Variable(
    Str                            $name,
    HashRef[Str]                  :$attr={},
    Maybe[Shape]                  :$shape=,
    Maybe[Num]                    :$lr_mult=,
    Maybe[Num]                    :$wd_mult=,
    Maybe[Dtype]                  :$dtype=,
    Maybe[Initializer]            :$init=,
    HashRef[Str]                  :$kwargs={},
    Maybe[Str]                    :$__layout__=
)
{
    # Create the bare variable node; its attributes are attached at the very end.
    my $handle = check_call(AI::MXNetCAPI::SymbolCreateVariable($name));
    my $ret = __PACKAGE__->new(handle => $handle);
    # The result of the current AttrScope's get() replaces the supplied hash ref
    # (presumably merging in scope-level attributes -- confirm against AttrScope).
    $attr = AI::MXNet::Symbol::AttrScope->current->get($attr);
    # The shape is stored as a string of the form "(d1,d2,...)".
    $attr->{__shape__}   = "(".join(',', @{ $shape }).")" if $shape;
    $attr->{__lr_mult__} =  $lr_mult if defined $lr_mult;
    $attr->{__wd_mult__} =  $wd_mult if defined $wd_mult;
    # The dtype name is translated through the DTYPE_STR_TO_MX lookup table.
    $attr->{__dtype__}   = DTYPE_STR_TO_MX->{ $dtype } if $dtype;
    # The initializer is stored in its stringified form.
    $attr->{__init__}    = "$init" if defined $init;
    $attr->{__layout__}  = $__layout__ if defined $__layout__;
    # Iterate over a sorted key list rather than each(): confess() below leaves
    # the loop early, and an abandoned each() iterator would stay positioned in
    # the middle of the caller's hash, so a retry with the same hash ref could
    # silently skip entries. Sorting also makes the reported key deterministic.
    for my $k (sort keys %{ $kwargs })
    {
        if($k =~ /^__/ and $k =~ /__$/)
        {
            # values are stored stringified
            $attr->{$k} = "$kwargs->{$k}";
        }
        else
        {
            confess("Attribute name=$k is not supported.".
                    ' Additional attributes must start and end with double underscores,'.
                    ' e.g, __yourattr__'
            );
        }
    }
    $ret->_set_attr(%{ $attr });
    return $ret;
}

=head2 var

    A synonym to Variable.
=cut

# Install the Variable code ref under the name `var` in this package's symbol
# table, so both names invoke the same subroutine.
*var = \&Variable;

=head2 Group

    Create a symbol that groups symbols together.

    Parameters
    ----------
    symbols : array ref
        List of symbols to be grouped.

    Returns
    -------
    sym : Symbol
        The created group symbol.
=cut

method Group(ArrayRef[AI::MXNet::Symbol] $symbols)
{
    # Gather the handle of every member symbol, preserving the given order.
    my @member_handles;
    push @member_handles, $_->handle for @{ $symbols };

    # The C API receives the number of handles followed by the handle array.
    my $group_handle = check_call(
        AI::MXNetCAPI::SymbolCreateGroup(scalar(@member_handles), \@member_handles)
    );

    # Wrap the resulting handle in a new symbol object.
    return __PACKAGE__->new(handle => $group_handle);
}

=head2 load

    Load symbol from a JSON file.

    You can also use Storable to do the job if you only work with Perl.
    The advantage of load/save is the file is language agnostic.
    This means a file written by save can be loaded by the other language bindings of mxnet.
    You also get the benefit of being able to load/save directly from cloud storage (S3, HDFS).

    Parameters
    ----------
    fname : str
        The name of the file, examples:

        - `s3://my-bucket/path/my-s3-symbol`
        - `hdfs://my-bucket/path/my-hdfs-symbol`
        - `/path-to/my-local-symbol`

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    See Also
    --------
    AI::MXNet::Symbol->save : Used to save symbol into file.
=cut

method load(Str $fname)
{
    # Build the symbol straight from the handle produced for the given file.
    # scalar() keeps check_call in scalar context, as a scalar assignment would.
    return __PACKAGE__->new(
        handle => scalar(check_call(AI::MXNetCAPI::SymbolCreateFromFile($fname)))
    );
}

=head2 load_json

    Load symbol from json string.

    Parameters
    ----------
    json_str : str
        A json string.

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    See Also
    --------
    AI::MXNet::Symbol->tojson : Used to save symbol into json string.
=cut

method load_json(Str $json)
{
    # Obtain a symbol handle for the JSON text, then wrap it in a new object.
    my $json_handle = check_call(
        AI::MXNetCAPI::SymbolCreateFromJSON($json)
    );
    my %ctor_args = (handle => $json_handle);
    return __PACKAGE__->new(%ctor_args);
}

method zeros(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=)
{
    # Arguments forwarded to the underlying _zeros operator.
    my %op_args = (
        shape => $shape,
        dtype => $dtype,
        name  => $name,
    );
    # The layout attribute is forwarded only when it holds a true value.
    $op_args{__layout__} = $__layout__ if $__layout__;
    return __PACKAGE__->_zeros(\%op_args);
}

method ones(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=)
{
    # Arguments forwarded to the underlying _ones operator.
    my %op_args = (
        shape => $shape,
        dtype => $dtype,
        name  => $name,
    );
    # The layout attribute is forwarded only when it holds a true value.
    $op_args{__layout__} = $__layout__ if $__layout__;
    return __PACKAGE__->_ones(\%op_args);
}

=head2 arange

    Similar to the arange function of the MXNet ndarray and to numpy.arange.
        See also https://docs.scipy.org/doc/numpy/reference/generated/numpy.arange.html.

    Parameters
    ----------
    start : number
        Start of interval. The interval includes this value. The default start value is 0.
    stop : number, optional
        End of interval. The interval does not include this value.
    step : number, optional
        Spacing between values
    repeat : int, optional
        The number of times each element is repeated.
        E.g. with repeat=3, the element a will be repeated three times --> a, a, a.
    dtype : type, optional
        The value type of the created Symbol, defaults to 'float32'.

    Returns
    -------
    out : Symbol
        The created Symbol
=cut

method arange(Index :$start=0, Index :$stop=, Num :$step=1.0, Index :$repeat=1, Maybe[Str] :$name=, Dtype :$dtype='float32')
{
    # Arguments that are always forwarded to the underlying _arange operator.
    my %op_args = (
        start  => $start,
        step   => $step,
        repeat => $repeat,
        name   => $name,
        dtype  => $dtype,
    );
    # 'stop' is passed along only when the caller supplied a defined value.
    $op_args{stop} = $stop if defined $stop;
    return __PACKAGE__->_arange(\%op_args);
}


# Classifies a flat argument list as either positional or keyword arguments.
#
# Parameters
# ----------
# $type : str
#     Name of the type constraint (looked up via find_type_constraint) that
#     every real argument value must satisfy.
# @args : list
#     Either positional values (each undef or of type $type), or a flat
#     name => value list whose values are all of type $type.
#
# Returns
# -------
# (\@positional_arguments, \%kwargs, \@kwargs_order)
#     At most one of the positional/keyword containers is populated.
#     @kwargs_order lists the keyword names in the order they were passed.
#
# Dies (confess) when the arguments fit neither form.
sub _parse_arguments
{
    my ($type, @args) = @_;
    my $type_c = find_type_constraint($type);
    my $str_c  = find_type_constraint("Str");
    my (@positional_arguments, %kwargs, @kwargs_order);

    my $only_dtypes_and_undefs = (@args == grep { not defined($_) or $type_c->check($_) } @args);
    my $only_dtypes_and_strs   = (@args == grep { $type_c->check($_) or $str_c->check($_) } @args);

    if($only_dtypes_and_undefs)
    {
        # Every element is undef or of the requested type: purely positional,
        # regardless of whether the element count is odd or even.
        @positional_arguments = @args;
    }
    elsif($only_dtypes_and_strs)
    {
        # A keyword list must consist of complete name => value pairs. Reject an
        # odd-length list up front instead of letting the hash assignment warn
        # ("Odd number of elements") and pair the last name with undef.
        confess("Argument need to be of type $type") if @args % 2;

        my %tmp = @args;
        my @mistyped = grep { not $type_c->check($_) } values(%tmp);
        confess("Argument need to be of type $type") if @mistyped;

        %kwargs = %tmp;
        # Names sit at the even indices of the flat name => value list.
        @kwargs_order = @args[ grep { not $_ % 2 } 0 .. $#args ];
    }
    else
    {
        confess("Argument need to be one type $type");
    }
    return (\@positional_arguments, \%kwargs, \@kwargs_order);
}

sub  _ufunc_helper
{
    my ($lhs, $rhs, $fn_symbol, $lfn_scalar, $rfn_scalar, $reverse) = @_;
    ($rhs, $lhs) = ($lhs, $rhs) if $reverse and $rfn_scalar;
    if(not ref $lhs)
    {
        if(not $rfn_scalar)
        {
            return __PACKAGE__->can($lfn_scalar)->(__PACKAGE__, $rhs, { "scalar" => $lhs });
        }
        else
        {



( run in 0.780 second using v1.01-cache-2.11-cpan-39bf76dae61 )