AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
AI::MXNet::RNN::Cell->new(
num_hidden => $self->_num_hidden,
activation => 'tanh',
prefix => shift
)
},
lstm => sub {
AI::MXNet::RNN::LSTMCell->new(
num_hidden => $self->_num_hidden,
prefix => shift
)
},
gru => sub {
AI::MXNet::RNN::GRUCell->new(
num_hidden => $self->_num_hidden,
prefix => shift
)
},
}->{ $self->_mode };
for my $i (0..$self->_num_layers-1)
{
if($self->_bidirectional)
{
$stack->add(
AI::MXNet::RNN::BidirectionalCell->new(
$get_cell->(sprintf('%sl%d_', $self->_prefix, $i)),
$get_cell->(sprintf('%sr%d_', $self->_prefix, $i)),
output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $i)
)
);
}
else
{
$stack->add($get_cell->(sprintf('%sl%d_', $self->_prefix, $i)));
}
}
return $stack;
}
package AI::MXNet::RNN::SequentialCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';
=head1 NAME
AI::MXNet::RNN::SequentialCell
=cut
=head1 DESCRIPTION
Sequentially stacks multiple RNN cells: the output of each cell
is fed as the input of the next one.
Parameters
----------
params : AI::MXNet::RNN::Params or undef
container for weight sharing between cells.
created if undef.
=cut
# _override_cell_params: true when 'params' was passed to new(); the
#   stack's parameters are then pushed down into every added cell.
# _cells: the child cells, in the order they were added.
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);
sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    $self->_cells([]);
}
=head2 add
Append a cell to the stack.
Parameters
----------
$cell : AI::MXNet::RNN::Cell::Base
=cut
method add(AI::MXNet::RNN::Cell::Base $cell)
{
    push @{ $self->_cells }, $cell;
    if($self->_override_cell_params)
    {
        # Per the assertion message: params may be given to the stack
        # or to the child cell, not to both.
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        # The child picks up the stack's parameters ...
        %{ $cell->params->_params } = (%{ $cell->params->_params }, %{ $self->params->_params });
    }
    # ... and the stack always collects the child's parameters.
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $cell->params->_params });
}
# Delegates to the base-class helper over the stacked cells.
method state_info()
{
    return $self->_cells_state_info($self->_cells);
}
# Initial states for every stacked cell (via the base-class helper).
# Dies if this cell has been wrapped by a modifier cell.
method begin_state(@kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_unpack_weights($self->_cells, $args);
}
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_pack_weights($self->_cells, $args);
}
# Runs a single time step through every cell in order.
# $states is the flat arrayref of states for all cells; each cell consumes
# as many entries as its state_info has. Returns the last cell's output
# and a flat arrayref with the new states of all cells.
method call($inputs, $states)
{
    $self->_counter($self->_counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell (@{ $self->_cells })
    {
        # A BidirectionalCell cannot be stepped. The package name must match
        # the real class (AI::MXNet::RNN::BidirectionalCell), otherwise this
        # guard never fires. '!' instead of 'not' so the message stays a
        # separate argument.
        assert(
            !$cell->isa('AI::MXNet::RNN::BidirectionalCell'),
            "BidirectionalCell cannot be stepped inside SequentialCell, use unroll"
        );
        my $n = scalar(@{ $cell->state_info });
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = &{$cell}($inputs, $state);
        push @next_states, $state;
    }
    return ($inputs, [map { @$_ } @next_states]);
}
# Unrolls every cell over $length steps, feeding each cell's outputs
# into the next cell. $begin_state is the flat list of states for all
# cells (defaults to begin_state()). Returns the final outputs and a flat
# arrayref of the end states of all cells.
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    my $num_cells = @{ $self->_cells };
    $begin_state //= $self->begin_state;
    my $p = 0;
    my @next_states;
    enumerate(sub {
        my ($i, $cell) = @_;
        my $n = @{ $cell->state_info };
        my $states = [@{$begin_state}[$p..$p+$n-1]];
        $p += $n;
        # Only the last cell receives the caller's merge_outputs;
        # intermediate cells are passed undef.
        ($inputs, $states) = $cell->unroll(
            $length,
            inputs        => $inputs,
            input_prefix  => $input_prefix,
            begin_state   => $states,
            layout        => $layout,
            merge_outputs => ($i < $num_cells-1) ? undef : $merge_outputs
        );
        push @next_states, $states;
    }, $self->_cells);
    return ($inputs, [map { @{ $_ } } @next_states]);
}
package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';
=head1 NAME
AI::MXNet::RNN::BidirectionalCell
=cut
=head1 DESCRIPTION
Bidirectional RNN cell built from two child cells: one unrolled
forward over the sequence, the other unrolled backward.
Parameters
----------
l_cell : AI::MXNet::RNN::Cell::Base
cell for forward unrolling
r_cell : AI::MXNet::RNN::Cell::Base
cell for backward unrolling
output_prefix : str, default 'bi_'
prefix for name of output
=cut
# The two child cells; both are mandatory.
has l_cell => (is => 'ro', required => 1, isa => 'AI::MXNet::RNN::Cell::Base');
has r_cell => (is => 'ro', required => 1, isa => 'AI::MXNet::RNN::Cell::Base');
# Passed to new() as 'output_prefix', stored under a private name.
has _output_prefix => (is => 'ro', isa => 'Str', init_arg => 'output_prefix', default => 'bi_');
# Internal state filled in by BUILD; cannot be set through new().
has ['_override_cell_params', '_cells'] => (is => 'rw', init_arg => undef);
# Accepts the two child cells positionally as well as by name:
#   ->new($l_cell, $r_cell, %other_args)
#   ->new(l_cell => $l_cell, r_cell => $r_cell, %other_args)
# Positional form is detected when the first two arguments are objects.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    if(@args >= 2 and blessed($args[0]) and blessed($args[1]))
    {
        my ($forward, $backward, @rest) = @args;
        return $class->$orig(l_cell => $forward, r_cell => $backward, @rest);
    }
    return $class->$orig(@args);
};
# Wires up parameter sharing between this cell and its two children.
# When 'params' was passed to new(), the children must own their params
# (see the assertion message), and they first receive this cell's params.
# This cell then always collects the params of l_cell and then r_cell.
sub BUILD
{
    my ($self, $original_arguments) = @_;
    my @children = ($self->l_cell, $self->r_cell);
    my $overriding = defined $original_arguments->{params};
    $self->_override_cell_params($overriding);
    if($self->_override_cell_params)
    {
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
        );
        # Push this cell's params down into each child (l_cell first).
        for my $child (@children)
        {
            %{ $child->params->_params } = (%{ $child->params->_params }, %{ $self->params->_params });
        }
    }
    # Pull each child's params up into this cell (l_cell first). This must
    # stay a separate pass so r_cell never sees l_cell's params.
    for my $child (@children)
    {
        %{ $self->params->_params } = (%{ $self->params->_params }, %{ $child->params->_params });
    }
    $self->_cells(\@children);
}
# Delegates to the base-class helper, applied to [l_cell, r_cell].
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $children = $self->_cells;
    return $self->_cells_unpack_weights($children, $args);
}
# Delegates to the base-class helper, applied to [l_cell, r_cell].
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $children = $self->_cells;
    return $self->_cells_pack_weights($children, $args);
}
# Single-step execution is not supported for this cell: it always dies
# with a stack trace and points the caller at unroll().
method call($inputs, $states)
{
    confess "Bidirectional cannot be stepped. Please use unroll";
}
# State descriptions of both children, via the base-class helper
# applied to [l_cell, r_cell].
method state_info()
{
    my $children = $self->_cells;
    return $self->_cells_state_info($children);
}
# Initial states for both children, via the base-class helper; @kwargs
# is forwarded unchanged. Dies if this cell has been wrapped by a
# modifier cell (the wrapper must be used instead).
method begin_state(@kwargs)
{
    my $wrapped = $self->_modified;
    assert(!$wrapped,
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my $children = $self->_cells;
    return $self->_cells_begin_state($children, @kwargs);
}
method unroll(
Int $length,
Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
Str :$input_prefix='',
Str :$layout='NTC',
Maybe[Bool] :$merge_outputs=
)
{
my $axis = index($layout, 'T');
if(not defined $inputs)
{
$inputs = [
map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
];
}
( run in 2.172 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )