AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # No implementation at this level: every invocation fails with
    # "Not Implemented". Presumably concrete cell classes override this
    # method -- confirm against the subclasses.
    confess('Not Implemented');
}

# Returns an array ref holding exactly one element, the empty string.
# Presumably cells with several gates override this with their own
# list of names -- confirm against the subclasses.
method _gate_names()
{
    return [ q{} ];
}

=head2 params

    Parameters of this cell
=cut

method params()
{
    # Reset the _own_params flag to 0 first, then hand back whatever
    # the _params accessor yields.
    $self->_own_params(0);

    return $self->_params;
}

=head2 state_shape

    shape(s) of states
=cut

method state_shape()
{
    # Collect the 'shape' entry of every state description, in the
    # order state_info reports them.
    my @shapes = map { $_->{shape} } @{ $self->state_info };

    return \@shapes;
}

=head2 state_info

    shape and layout information of states
=cut

method state_info()
{
    # No implementation at this level: every invocation fails with
    # "Not Implemented". state_shape and begin_state in this file
    # dereference the result as an array ref of hash refs.
    confess('Not Implemented');
}

=head2 begin_state

    Initial state for this cell.

    Parameters
    ----------
    :$func : sub ref, default is AI::MXNet::Symbol->can('zeros')
        Function for creating initial state.
        Can be AI::MXNet::Symbol->can('zeros'),
        AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc.
        Use AI::MXNet::Symbol->can('Variable') if you want to directly
        feed the input as states.
    @kwargs :
        more keyword arguments passed to func. For example
        mean, std, dtype, etc.

    Returns
    -------
    $states : ArrayRef[AI::MXNet::Symbol]
        starting states for first RNN step
=cut

method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );

    # When $func is Symbol's Variable the generated name is passed as a bare
    # leading argument; for every other function it is passed as name => ...
    my $is_variable_func = $func eq AI::MXNet::Symbol->can('Variable');

    my @begin_states;
    for my $state_info (@{ $self->state_info })
    {
        # Bump the per-cell counter so each generated state name is distinct.
        $self->_init_counter($self->_init_counter + 1);
        my $state_name = sprintf(
            "%sbegin_state_%d", $self->_prefix, $self->_init_counter
        );

        # Work on a shallow copy so the entry returned by state_info
        # is never modified.
        my %spec = %{ $state_info // {} };

        my @name_args = $is_variable_func
            ? ($state_name)
            : (name => $state_name);

        # For Variable, a __layout__ entry is moved under a nested
        # 'kwargs' hash instead of being passed as a top-level key.
        if($is_variable_func and exists $spec{__layout__})
        {
            $spec{kwargs} = { __layout__ => delete $spec{__layout__} };
        }

        # Entries from the state description override caller-supplied
        # keyword arguments with the same key.
        my %call_args = (@kwargs, %spec);

        my $state = $func->('AI::MXNet::Symbol', @name_args, %call_args);
        push @begin_states, $state;
    }

    return \@begin_states;
}

=head2 unpack_weights

    Unpack fused weight matrices into separate
    weight matrices

    Parameters
    ----------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref containing packed weights.
        usually from AI::MXNet::Module->get_output()

    Returns
    -------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref with weights associated with
        this cell, unpacked.
=cut

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $h = $self->_num_hidden;



( run in 2.440 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )