AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

=cut

=head1 DESCRIPTION

    Applies dropout to the output and the states returned by the base cell.
=cut

# Runs the base cell for one step and applies dropout to what it returns.
#
# Parameters
#   $inputs : AI::MXNet::Symbol, forwarded unchanged to the base cell.
#   $states : SymbolOrArrayOfSymbols, forwarded unchanged to the base cell.
#
# Returns the list ($output, $states): the base cell's output and states,
# where the output is wrapped in a Dropout op when dropout_outputs > 0 and
# every state is wrapped in a Dropout op when dropout_states > 0.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # The states returned by the base cell get their own name so that the
    # $states parameter is not re-declared with "my" in the same scope.
    my ($output, $next_states) = &{$self->base_cell}($inputs, $states);
    if($self->dropout_outputs > 0)
    {
        $output = AI::MXNet::Symbol->Dropout(data => $output, p => $self->dropout_outputs);
    }
    if($self->dropout_states > 0)
    {
        # Each state symbol gets its own Dropout op.
        $next_states = [
            map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) } @{ $next_states }
        ];
    }
    return ($output, $next_states);
}

package AI::MXNet::RNN::ZoneoutCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
# Zoneout rates for the cell output and for the cell states. A value of 0
# (the default) leaves that part untouched; any other value is passed as the
# Dropout "p" used to build the selection mask in call().
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Output produced by the most recent call(); cleared by reset().
# init_arg => undef means it cannot be supplied through the constructor.
has 'prev_output' => (is => 'rw', init_arg => undef);

=head1 NAME

    AI::MXNet::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

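    Applies Zoneout to the base cell: a Dropout-generated mask chooses between
    the newly computed output/states and the previous ones.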
=cut

# Construction hook: rejects base cells that zoneout cannot wrap.
#
# Three conditions are checked with assert() (imported, defined outside this
# view): the base cell must not be a FusedCell, must not be a
# BidirectionalCell, and must not be a bidirectional SequentialCell.
sub BUILD
{
    my $self = shift;
    my $base_cell = $self->base_cell;
    assert(
        (not $base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. ".
        "Please unfuse first."
    );
    assert(
        (not $base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # The bidirectional flag is read from the wrapped cell, which is what the
    # message below talks about, rather than from this modifier object.
    # NOTE(review): SequentialCell is not visible here; can() keeps the check
    # from dying if that class does not expose a _bidirectional accessor.
    my $bidirectional_sequential =
        $base_cell->isa('AI::MXNet::RNN::SequentialCell')
        && $base_cell->can('_bidirectional')
        && $base_cell->_bidirectional;
    assert(
        (not $bidirectional_sequential),
        "Bidirectional SequentialCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}

# Resets the cell: delegates to the parent class's reset, then discards the
# output remembered from the previous call().
method reset()
{
    $self->SUPER::reset();
    # With prev_output undefined, the next call() falls back to a zeros symbol.
    $self->prev_output(undef);
}

# Runs the base cell for one step and applies zoneout to its results.
#
# Parameters
#   $inputs : AI::MXNet::Symbol, forwarded unchanged to the base cell.
#   $states : SymbolOrArrayOfSymbols, forwarded to the base cell and also used
#             as the "previous" states when zoneout_states is non-zero.
#
# Returns the list ($output, $states). The output is also stored in
# prev_output for use by the next call.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $base         = $self->base_cell;
    my $rate_outputs = $self->zoneout_outputs;
    my $rate_states  = $self->zoneout_states;
    my ($fresh_output, $fresh_states) = $base->($inputs, $states);

    # Builds the selection mask: Dropout applied to a ones symbol shaped like
    # $template, with $rate as the dropout probability.
    my $keep_mask = sub {
        my ($rate, $template) = @_;
        return AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like($template),
            p => $rate
        );
    };

    # On the first step after a reset there is no remembered output, so a
    # zeros symbol stands in for it.
    my $held_output = $self->prev_output;
    $held_output = AI::MXNet::Symbol->zeros(shape => [0, 0]) unless defined $held_output;

    my $output = $fresh_output;
    if($rate_outputs != 0)
    {
        $output = AI::MXNet::Symbol->where(
            $keep_mask->($rate_outputs, $fresh_output),
            $fresh_output,
            $held_output
        );
    }

    my @mixed_states;
    if($rate_states != 0)
    {
        # Pair every new state with the matching incoming state.
        zip(sub {
            my ($fresh, $previous) = @_;
            push @mixed_states, AI::MXNet::Symbol->where(
                $keep_mask->($rate_states, $fresh),
                $fresh,
                $previous
            );
        }, $fresh_states, $states);
    }

    $self->prev_output($output);
    # If no mixed states were produced, hand back the base cell's states as-is.
    my $states_out = @mixed_states ? \@mixed_states : $fresh_states;
    return ($output, $states_out);
}

package AI::MXNet::RNN::ResidualCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::RNN::ResidualCell
=cut

=head1 DESCRIPTION

    Adds residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is output of the base cell plus input.
=cut

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $output;



( run in 0.711 second using v1.01-cache-2.11-cpan-ceb78f64989 )