AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
=cut
=head1 DESCRIPTION
Apply dropout to the base cell's output and/or states.
=cut
# Run the wrapped cell for one step, then apply Dropout to its output
# (when dropout_outputs > 0) and to each of its states (when
# dropout_states > 0). Returns ($output, $states) in the same form the
# base cell returned them.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Use distinct names for the base cell's results: re-declaring $states
    # with `my` here would mask the $states parameter in the same scope.
    my ($next_output, $next_states) = &{$self->base_cell}($inputs, $states);
    if($self->dropout_outputs > 0)
    {
        $next_output = AI::MXNet::Symbol->Dropout(data => $next_output, p => $self->dropout_outputs);
    }
    if($self->dropout_states > 0)
    {
        # The base cell's states are dereferenced as an array ref here.
        $next_states = [
            map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) } @{ $next_states }
        ];
    }
    return ($next_output, $next_states);
}
package AI::MXNet::RNN::ZoneoutCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
# Rates passed to Dropout when building the zoneout masks; 0 disables
# zoneout for outputs / states respectively.
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Output produced by the previous call to call(); cleared by reset().
has 'prev_output' => (is => 'rw', init_arg => undef);

=head1 NAME

AI::MXNet::RNN::ZoneoutCell

=cut

=head1 DESCRIPTION

Apply Zoneout on base cell.

On each step the base cell's new output (and, optionally, its new states)
is mixed element-wise with the previous value using a random mask built
with Dropout.

Parameters
----------
zoneout_outputs : Num, default 0
    Dropout rate used for the output mask; 0 disables zoneout on outputs.
zoneout_states : Num, default 0
    Dropout rate used for the state masks; 0 disables zoneout on states.

FusedCell, BidirectionalCell and bidirectional SequentialCell base cells
are rejected at construction time.

=cut

sub BUILD
{
    my $self = shift;
    my $base_cell = $self->base_cell;
    assert(
        (not $base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. ".
        "Please unfuse first."
    );
    assert(
        (not $base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # The bidirectional flag belongs to the wrapped SequentialCell, not to
    # this modifier; can() guards against cells that don't define it.
    assert(
        (
            not $base_cell->isa('AI::MXNet::RNN::SequentialCell')
            or not ($base_cell->can('_bidirectional') and $base_cell->_bidirectional)
        ),
        "Bidirectional SequentialCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}

# Reset the parent state and forget the output remembered from the last step.
method reset()
{
    $self->SUPER::reset;
    $self->prev_output(undef);
}

# Run one step of the base cell and apply zoneout to its output and states.
# Returns ($output, $states); $states is the base cell's states unchanged
# when zoneout_states is 0.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = &{$cell}($inputs, $states);
    # Random mask shaped like $like: Dropout with rate $p applied to a
    # tensor of ones. where() below keeps the new value where the mask is
    # non-zero and falls back to the old value elsewhere.
    my $mask = sub {
        my ($p, $like) = @_;
        AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like(
                $like
            ),
            p => $p
        );
    };
    # No previous output yet (first step or after reset): use a
    # zeros(shape => [0, 0]) symbol as the stand-in.
    my $prev_output = $self->prev_output // AI::MXNet::Symbol->zeros(shape => [0, 0]);
    my $output = $p_outputs != 0
        ? AI::MXNet::Symbol->where(
            &{$mask}($p_outputs, $next_output),
            $next_output,
            $prev_output
        )
        : $next_output;
    my @states;
    if($p_states != 0)
    {
        # Pair each new state with the corresponding incoming state.
        zip(sub {
            my ($new_s, $old_s) = @_;
            push @states, AI::MXNet::Symbol->where(
                &{$mask}($p_states, $new_s),
                $new_s,
                $old_s
            );
        }, $next_states, $states);
    }
    # Remember this step's output for the next call.
    $self->prev_output($output);
    return ($output, @states ? \@states : $next_states);
}
package AI::MXNet::RNN::ResidualCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
=head1 NAME
AI::MXNet::RNN::ResidualCell
=cut
=head1 DESCRIPTION
Adds a residual connection, as described in Wu et al., 2016
(https://arxiv.org/abs/1609.08144).
The output of the cell is the output of the base cell plus its input.
=cut
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
my $output;
( run in 0.711 second using v1.01-cache-2.11-cpan-ceb78f64989 )