AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'tanh',
                prefix     => shift
            )
        },
        lstm     => sub {
            AI::MXNet::RNN::LSTMCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
        gru      => sub {
            AI::MXNet::RNN::GRUCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
    }->{ $self->_mode };
    for my $i (0..$self->_num_layers-1)
    {
        if($self->_bidirectional)
        {
            $stack->add(
                AI::MXNet::RNN::BidirectionalCell->new(
                    $get_cell->(sprintf('%sl%d_', $self->_prefix, $i)),
                    $get_cell->(sprintf('%sr%d_', $self->_prefix, $i)),
                    output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $i)
                )
            );
        }
        else
        {
            $stack->add($get_cell->(sprintf('%sl%d_', $self->_prefix, $i)));
        }
    }
    return $stack;
}

package AI::MXNet::RNN::SequentialCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::SequentialCell
=cut

=head1 DESCRIPTION

    Sequentially stacking multiple RNN cells

    Parameters
    ----------
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

# Internal bookkeeping attributes; init_arg => undef keeps them out of the
# constructor arguments.
#   _override_cell_params : set in BUILD from whether a "params" argument was given
#   _cells                : array ref of the stacked cells, filled by add()
has '_override_cell_params' => (is => 'rw', init_arg => undef);
has '_cells'                => (is => 'rw', init_arg => undef);

# Post-construction hook: starts with an empty stack of cells and records
# whether the constructor received a defined "params" argument.
sub BUILD
{
    my $self = shift;
    my $ctor_args = shift;
    my $params_given = defined $ctor_args->{params};
    $self->_override_cell_params($params_given);
    $self->_cells([]);
}

=head2 add

    Append a cell to the stack.

    Parameters
    ----------
    $cell : AI::MXNet::RNN::Cell::Base
=cut

# Appends $cell to the stack and merges the parameter hashes.
#
# When the stack was constructed with its own "params", the stack's entries
# are first copied into the cell's parameter hash (stack entries win on key
# clashes); the cell must pass the _own_params check for that to be allowed.
# In every case the cell's entries are then merged into the stack's hash
# (cell entries win on key clashes).
method add(AI::MXNet::RNN::Cell::Base $cell)
{
    my $stack = $self->_cells;
    push @{ $stack }, $cell;
    if($self->_override_cell_params)
    {
        # params may be supplied to the stack or to the child cells, not both
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        my $cell_store  = $cell->params->_params;
        my $stack_store = $self->params->_params;
        %{ $cell_store } = (%{ $cell_store }, %{ $stack_store });
    }
    my $stack_store = $self->params->_params;
    my $cell_store  = $cell->params->_params;
    %{ $stack_store } = (%{ $stack_store }, %{ $cell_store });
}

# Describes the states of the whole stack by delegating to
# _cells_state_info with the list of stacked cells.
method state_info()
{
    my $stacked_cells = $self->_cells;
    return $self->_cells_state_info($stacked_cells);
}

# Builds the initial states for the stacked cells by delegating to
# _cells_begin_state; @kwargs are passed through unchanged.
# Fails the assertion when this cell has been marked as modified.
method begin_state(@kwargs)
{
    my $is_modified = $self->_modified;
    assert(
        (not $is_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my $stacked_cells = $self->_cells;
    return $self->_cells_begin_state($stacked_cells, @kwargs);
}

# Unpacks the weights in $args for every stacked cell by delegating to
# _cells_unpack_weights.
#
# $args : hash ref of AI::MXNet::NDArray values
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $stacked_cells = $self->_cells;
    return $self->_cells_unpack_weights($stacked_cells, $args);
}

# Packs the weights in $args for every stacked cell by delegating to
# _cells_pack_weights.
#
# $args : hash ref of AI::MXNet::NDArray values
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $stacked_cells = $self->_cells;
    return $self->_cells_pack_weights($stacked_cells, $args);
}

# Runs one step through every stacked cell, in stack order.
#
# Parameters
# ----------
# $inputs : input of this step; it is fed to the first cell and each cell's
#           output becomes the input of the next cell
# $states : array ref with the states of all cells laid out one after
#           another in stack order; each cell consumes as many entries as
#           its state_info reports
#
# Returns
# -------
# ($output, $next_states) : the output of the last cell and an array ref
#           with the new states of all cells, flattened in stack order
#
# Dies (via assert) when one of the stacked cells is a BidirectionalCell.
method call($inputs, $states)
{
    $self->_counter($self->_counter + 1);
    my @next_states;
    my $offset = 0;
    for my $cell (@{ $self->_cells })
    {
        # The bidirectional cell lives in the AI::MXNet::RNN namespace; the
        # check has to name the full package or it can never match.
        assert(
            (not $cell->isa('AI::MXNet::RNN::BidirectionalCell')),
            "Bidirectional cannot be stepped. Please use unroll"
        );
        my $num_states  = scalar(@{ $cell->state_info });
        # slice out the part of $states that belongs to this cell
        my $cell_states = [@{ $states }[$offset..$offset+$num_states-1]];
        $offset += $num_states;
        ($inputs, $cell_states) = $cell->($inputs, $cell_states);
        push @next_states, $cell_states;
    }
    return ($inputs, [map { @{ $_ } } @next_states]);
}

# Unrolls every stacked cell for $length steps, chaining the outputs of one
# cell into the inputs of the next.
#
# Parameters
# ----------
# $length        : number of steps, passed to each cell's unroll
# :$inputs       : inputs for the first cell (passed through as given)
# :$begin_state  : array ref of initial states for all cells in stack order;
#                  defaults to $self->begin_state when not defined
# :$input_prefix : passed unchanged to each cell's unroll
# :$layout       : passed unchanged to each cell's unroll
# :$merge_outputs: only handed to the last cell; earlier cells get undef
#
# Returns
# -------
# ($outputs, $states) : the outputs of the last cell and an array ref with
#                  the resulting states of all cells, flattened in stack order
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    my $last_index = @{ $self->_cells } - 1;
    $begin_state //= $self->begin_state;
    my @cells  = @{ $self->_cells };
    my $offset = 0;
    my @states_per_cell;
    for my $index (0..$#cells)
    {
        my $cell       = $cells[$index];
        my $num_states = @{ $cell->state_info };
        # the slice of $begin_state that belongs to this cell
        my $cell_states = [@{ $begin_state }[$offset..$offset+$num_states-1]];
        $offset += $num_states;
        # intermediate cells must not merge; only the last one honours the request
        my $merge = ($index < $last_index) ? undef : $merge_outputs;
        ($inputs, $cell_states) = $cell->unroll(
            $length,
            inputs          => $inputs,
            input_prefix    => $input_prefix,
            begin_state     => $cell_states,
            layout          => $layout,
            merge_outputs   => $merge
        );
        push @states_per_cell, $cell_states;
    }
    return ($inputs, [map { @{ $_ } } @states_per_cell]);
}

package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell

    Parameters
    ----------
    l_cell : AI::MXNet::RNN::Cell::Base
        cell for forward unrolling
    r_cell : AI::MXNet::RNN::Cell::Base
        cell for backward unrolling
    output_prefix : str, default 'bi_'
        prefix for name of output
=cut

# Cell used for forward unrolling (required constructor argument).
has 'l_cell' => (
    is       => 'ro',
    isa      => 'AI::MXNet::RNN::Cell::Base',
    required => 1
);
# Cell used for backward unrolling (required constructor argument).
has 'r_cell' => (
    is       => 'ro',
    isa      => 'AI::MXNet::RNN::Cell::Base',
    required => 1
);
# Prefix for the name of the output; passed to the constructor as "output_prefix".
has '_output_prefix' => (
    is       => 'ro',
    init_arg => 'output_prefix',
    isa      => 'Str',
    default  => 'bi_'
);
# Internal bookkeeping, not settable through the constructor; both are set in BUILD.
has '_override_cell_params' => (is => 'rw', init_arg => undef);
has '_cells'                => (is => 'rw', init_arg => undef);

# Allows the two cells to be given positionally:
#     ->new($l_cell, $r_cell, key => value, ...)
# When the first two arguments are both blessed references they are turned
# into the named arguments l_cell and r_cell; otherwise the argument list is
# handed to the original BUILDARGS untouched.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    my $cells_are_positional = (@args >= 2 and blessed $args[0] and blessed $args[1]);
    return $class->$orig(@args) unless $cells_are_positional;

    my ($l_cell, $r_cell, @named) = @args;
    return $class->$orig(
        l_cell => $l_cell,
        r_cell => $r_cell,
        @named
    );
};

# Post-construction hook.
#
# Records whether a defined "params" argument was given. If so, both child
# cells must pass the _own_params check, and this cell's parameter entries
# are copied into each child's parameter hash (this cell's entries win on
# key clashes). Afterwards the entries of the left and then the right cell
# are merged into this cell's parameter hash, and both cells are stored in
# _cells as [l_cell, r_cell].
sub BUILD
{
    my ($self, $ctor_args) = @_;
    my @children = ($self->l_cell, $self->r_cell);
    $self->_override_cell_params(defined $ctor_args->{params});
    if($self->_override_cell_params)
    {
        # params may be supplied to this cell or to the child cells, not both
        assert(
            ($children[0]->_own_params and $children[1]->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
        );
        for my $child (@children)
        {
            my $child_store  = $child->params->_params;
            my $shared_store = $self->params->_params;
            %{ $child_store } = (%{ $child_store }, %{ $shared_store });
        }
    }
    for my $child (@children)
    {
        my $own_store   = $self->params->_params;
        my $child_store = $child->params->_params;
        %{ $own_store } = (%{ $own_store }, %{ $child_store });
    }
    $self->_cells([@children]);
}

# Unpacks the weights in $args for the left and right cells by delegating
# to _cells_unpack_weights.
#
# $args : hash ref of AI::MXNet::NDArray values
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $both_cells = $self->_cells;
    return $self->_cells_unpack_weights($both_cells, $args);
}

# Packs the weights in $args for the left and right cells by delegating
# to _cells_pack_weights.
#
# $args : hash ref of AI::MXNet::NDArray values
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $both_cells = $self->_cells;
    return $self->_cells_pack_weights($both_cells, $args);
}

# Single-step invocation is not supported for the bidirectional cell:
# this method always dies with a stack trace. Use unroll instead.
method call($inputs, $states)
{
    my $message = "Bidirectional cannot be stepped. Please use unroll";
    confess($message);
}

# Describes the states of both child cells by delegating to
# _cells_state_info with [l_cell, r_cell].
method state_info()
{
    my $both_cells = $self->_cells;
    return $self->_cells_state_info($both_cells);
}

# Builds the initial states for both child cells by delegating to
# _cells_begin_state; @kwargs are passed through unchanged.
# Fails the assertion when this cell has been marked as modified.
method begin_state(@kwargs)
{
    my $is_modified = $self->_modified;
    assert(
        (not $is_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my $both_cells = $self->_cells;
    return $self->_cells_begin_state($both_cells, @kwargs);
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{

    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }



( run in 2.172 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )