AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN


    Returns
    -------
    $out : array ref of strings.
=cut

method list_outputs()
{
    # check_call() is evaluated in scalar context here, so exactly one
    # value (the list of output names) is handed back to the caller.
    my $output_names = check_call(
        AI::MXNetCAPI::SymbolListOutputs($self->handle)
    );
    return $output_names;
}


=head2 list_auxiliary_states()

    List all auxiliary states in the symbol.

    Returns
    -------
    aux_states : array ref of string
        List the names of the auxiliary states.

    Notes
    -----
    Auxiliary states are special states of symbols that do not correspond to an argument
    and do not have a gradient, but are still useful for specific operations.
    A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
    Most operators do not have Auxiliary states.
=cut

method list_auxiliary_states()
{
    # Query the C API for the auxiliary state names of this symbol's handle;
    # the scalar assignment keeps check_call() in scalar context.
    my $aux_state_names = check_call(
        AI::MXNetCAPI::SymbolListAuxiliaryStates($self->handle)
    );
    return $aux_state_names;
}


=head2 list_inputs

    Lists all arguments and auxiliary states of this Symbol.

    Returns
    -------
    inputs : array ref of str
    List of all inputs.

    Examples
    --------
    >>> my $bn = mx->sym->BatchNorm(name=>'bn');
=cut

method list_inputs()
{
    # NOTE(review): the trailing 0 is passed straight through to the NNVM
    # C API; presumably it selects "all inputs" - confirm against the binding.
    my $input_names = check_call(
        AI::NNVMCAPI::SymbolListInputNames($self->handle, 0)
    );
    return $input_names;
}

=head2 infer_type

        Infer the type of outputs and arguments of given known types of arguments.

        User can either pass in the known types in positional way or keyword argument way.
        A list of undefs is returned if there is not enough information passed in.
        An error will be raised if there is inconsistency found in the known types passed in.

        Parameters
        ----------
        args : Array
            Provide type of arguments in a positional way.
            Unknown type can be marked as undef

        kwargs : Hash ref, must be supplied as the sole argument to the method.
            Provide keyword arguments of known types.

        Returns
        -------
        arg_types : array ref of Dtype or undef
            List of types of arguments.
            The order is in the same order as list_arguments()
        out_types : array ref of Dtype or undef
            List of types of outputs.
            The order is in the same order as list_outputs()
        aux_types : array ref of Dtype or undef
            List of types of auxiliary states.
            The order is in the same order as list_auxiliary_states()
=cut


method infer_type(Str|Undef @args)
{
    my ($positional_arguments, $kwargs, $kwargs_order) = _parse_arguments("Dtype", @args);
    my (@names, @type_codes);
    if(@{ $positional_arguments })
    {
        # Positional form: no names are sent, and an undef entry is encoded as -1.
        @type_codes = map { defined($_) ? DTYPE_STR_TO_MX->{ $_ } : -1 } @{ $positional_arguments };
    }
    else
    {
        # Keyword form: names and their type codes are sent in $kwargs_order order.
        @names      = @{ $kwargs_order };
        @type_codes = map { DTYPE_STR_TO_MX->{ $_ } } @{ $kwargs }{ @names };
    }
    my ($arg_type, $out_type, $aux_type, $complete) = check_call(
        AI::MXNetCAPI::SymbolInferType(
            $self->handle,
            scalar(@type_codes),
            \@names,
            \@type_codes
        )
    );
    # Inference did not complete: hand back three undefs.
    return (undef, undef, undef) unless $complete;

    # Translate one list of numeric type codes back into dtype names.
    my $to_dtype_names = sub {
        my ($codes) = @_;
        return [ map { DTYPE_MX_TO_STR->{ $_ } } @{ $codes } ];
    };
    return (
        $to_dtype_names->($arg_type),
        $to_dtype_names->($out_type),
        $to_dtype_names->($aux_type)
    );
}

=head2 infer_shape

        Infer the shape of outputs and arguments of given known shapes of arguments.

        User can either pass in the known shapes in positional way or keyword argument way.
        A list of undefs is returned if there is not enough information passed in.
        An error will be raised if there is inconsistency found in the known shapes passed in.

        Parameters
        ----------
        *args :
            Provide shape of arguments in a positional way.
            Unknown shape can be marked as undef

        **kwargs :
            Provide keyword arguments of known shapes.

        Returns
        -------
        arg_shapes : array ref of Shape or undef
            List of shapes of arguments.
            The order is in the same order as list_arguments()
        out_shapes : array ref of Shape or undef
            List of shapes of outputs.
            The order is in the same order as list_outputs()
        aux_shapes : array ref of Shape or undef
            List of shapes of auxiliary states.
            The order is in the same order as list_auxiliary_states()
=cut

method infer_shape(Maybe[Str|Shape] @args)
{
    my @res = $self->_infer_shape_impl(0, @args);
    # Second element defined means inference produced a result: nothing to report.
    return @res if defined $res[1];

    # Re-run in partial mode purely to find out which arguments are still
    # unresolved, so they can be named in the warning below.
    my ($partial_arg_shapes) = $self->_infer_shape_impl(1, @args);
    my $arg_names            = $self->list_arguments;
    my @unknowns;
    zip(sub {
        my ($name, $shape) = @_;
        # A shape counts as known only when it is a non-empty array ref
        # whose dimensions multiply to a non-zero value.
        if(ref($shape) and @{ $shape } and product(@{ $shape }))
        {
            # known shape, nothing to record
        }
        elsif(@unknowns < 10)
        {
            # eval guards against $shape not being an array ref (e.g. undef).
            my @dims = eval { @{ $shape } };
            push @unknowns, "$name @dims";
        }
        else
        {
            # Cap the report at ten entries followed by an ellipsis marker.
            $unknowns[10] = '...';
        }
    }, $arg_names, $partial_arg_shapes);

    # NOTE(review): the message contains two consecutive "\n\t" sequences
    # before the comma-joined list; kept as-is to preserve the emitted text.
    my $message = join('',
        "Cannot decide shape for the following arguments ",
        "(0s in shape means unknown dimensions). ",
        "Consider providing them as input:\n\t",
        "\n\t",
        join(", ", @unknowns)
    );
    AI::MXNet::Logging->warning($message);
    return @res;
}

=head2 infer_shape_partial

    Partially infer the shape. The same as infer_shape, except that the partial

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN


    Examples:
    my $result = $symbol->eval(ctx => mx->gpu, args => {data => mx->nd->ones([5,5])});
    my $result = $symbol->eval(args => {data => mx->nd->ones([5,5])});

=cut

method eval(:$ctx=AI::MXNet::Context->cpu, HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] :$args)
{
    # Bind the symbol with the given context and arguments, then run
    # forward on whatever bind() returns and pass its result through.
    my $bound = $self->bind(ctx => $ctx, args => $args);
    return $bound->forward;
}

=head2  grad

    Get the autodiff of current symbol.
    This function can only be used if current symbol is a loss function.

    Parameters
    ----------
    $wrt : Array ref of String
        Names of the symbol's arguments with respect to which the gradients are taken.

    Returns
    -------
    grad : AI::MXNet::Symbol
        A gradient Symbol whose outputs are the corresponding gradients.
=cut

method grad(ArrayRef[Str] $wrt)
{
    # The C API takes the number of names followed by the names themselves.
    my $wrt_count   = scalar(@{ $wrt });
    my $grad_handle = check_call(
        AI::MXNetCAPI::SymbolGrad($self->handle, $wrt_count, $wrt)
    );
    # Wrap the returned handle in a new symbol object of this package.
    return __PACKAGE__->new(handle => $grad_handle);
}

=head2 Variable

    Create a symbolic variable with specified name.

    Parameters
    ----------
    name : str
        Name of the variable.
    attr : hash ref of string -> string
        Additional attributes to set on the variable.
    shape : array ref of positive integers
        Optionally, one can specify the shape of a variable. This will be used during
        shape inference. If user specified a different shape for this variable using
        keyword argument when calling shape inference, this shape information will be ignored.
    lr_mult : float
        Specify learning rate multiplier for this variable.
    wd_mult : float
        Specify weight decay multiplier for this variable.
    dtype : Dtype
        Similar to shape, we can specify dtype for this variable.
    init : initializer (mx->init->*)
        Specify initializer for this variable to override the default initializer
    kwargs : hash ref
        other additional attribute variables
    Returns
    -------
    variable : Symbol
        The created variable symbol.
=cut

method Variable(
    Str                            $name,
    HashRef[Str]                  :$attr={},
    Maybe[Shape]                  :$shape=,
    Maybe[Num]                    :$lr_mult=,
    Maybe[Num]                    :$wd_mult=,
    Maybe[Dtype]                  :$dtype=,
    Maybe[Initializer]            :$init=,
    HashRef[Str]                  :$kwargs={},
    Maybe[Str]                    :$__layout__=
)
{
    # Create the bare variable in the C API and wrap its handle.
    my $handle = check_call(AI::MXNetCAPI::SymbolCreateVariable($name));
    my $ret = __PACKAGE__->new(handle => $handle);

    # Let the current attribute scope combine its attributes with $attr.
    # NOTE(review): get() may hand back the very hash ref the caller passed
    # in (or one it owns itself); work on a shallow copy so the assignments
    # below never leak into a hash this method does not own.
    my %attrs = %{ AI::MXNet::Symbol::AttrScope->current->get($attr) // {} };

    # Well-known options are stored as double-underscore attributes.
    $attrs{__shape__}   = "(".join(',', @{ $shape }).")" if $shape;
    $attrs{__lr_mult__} = $lr_mult if defined $lr_mult;
    $attrs{__wd_mult__} = $wd_mult if defined $wd_mult;
    $attrs{__dtype__}   = DTYPE_STR_TO_MX->{ $dtype } if $dtype;
    $attrs{__init__}    = "$init" if defined $init;
    $attrs{__layout__}  = $__layout__ if defined $__layout__;

    # Iterate over keys() rather than each(): confess() below leaves the loop
    # early, and an abandoned each() would leave the hash's internal iterator
    # part-way through for anyone who catches the error and reuses $kwargs.
    for my $k (keys %{ $kwargs })
    {
        if($k =~ /^__/ and $k =~ /__$/)
        {
            # Attribute values are always stored in stringified form.
            $attrs{$k} = "$kwargs->{$k}";
        }
        else
        {
            confess("Attribute name=$k is not supported.".
                    ' Additional attributes must start and end with double underscores,'.
                    ' e.g, __yourattr__'
            );
        }
    }
    $ret->_set_attr(%attrs);
    return $ret;
}

=head2 var

    A synonym to Variable.
=cut

# Install 'var' in this package's symbol table as a second name for the
# Variable method; both names refer to the same code ref.
*var = \&Variable;

=head2 Group

    Create a symbol that groups symbols together.

    Parameters
    ----------



( run in 1.874 second using v1.01-cache-2.11-cpan-39bf76dae61 )