AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
Returns
-------
$outputs : array ref of Symbol or Symbol
output symbols.
$states : Symbol or nested list of Symbol
has the same structure as begin_state()
=cut
# Unrolls the cell over $length time steps and returns ($outputs, $states).
#
# Inputs are prepared in one of three ways:
#   * a single Symbol is checked to have exactly one output and is then
#     sliced along the time axis into $length per-step symbols;
#   * an array ref is only checked to contain exactly $length elements;
#   * when nothing is supplied, $length fresh Variables named
#     "<input_prefix>t<i>_data" are created.
# The cell is then invoked once per step, threading the states through.
# When $merge_outputs is true the per-step outputs are given a time
# dimension and concatenated along it into a single symbol.
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    $self->reset;
    # Position of the time dimension inside the layout string.
    my $time_axis = index($layout, 'T');
    my $last_step = $length - 1;
    if(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis         => $time_axis,
            num_outputs  => $length,
            squeeze_axis => 1
        );
    }
    elsif(defined $inputs)
    {
        assert(@$inputs == $length);
    }
    else
    {
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$last_step)
        ];
    }
    my $states = $begin_state // $self->begin_state;
    # Left undefined on purpose: it is only autovivified by the first push.
    my $outputs;
    # NOTE(review): when $inputs is the sliced Symbol this relies on the
    # Symbol class supporting array dereference -- confirm in AI::MXNet::Symbol.
    my @step_inputs = @{ $inputs };
    for my $step_input (@step_inputs[0..$last_step])
    {
        # The cell object itself is invoked as a code reference.
        my ($step_output, $next_states) = $self->($step_input, $states);
        $states = $next_states;
        push @{ $outputs }, $step_output;
    }
    if($merge_outputs)
    {
        my @expanded = map {
            AI::MXNet::Symbol->expand_dims($_, axis => $time_axis)
        } @$outputs;
        $outputs = AI::MXNet::Symbol->Concat(@expanded, dim => $time_axis);
    }
    return($outputs, $states);
}
# Applies an activation to $inputs.
# A reference in $activation is called directly with ($inputs, @kwargs);
# anything else is passed as the act_type of AI::MXNet::Symbol->Activation.
method _get_activation($inputs, $activation, @kwargs)
{
    return $activation->($inputs, @kwargs) if ref $activation;
    return AI::MXNet::Symbol->Activation(
        $inputs,
        act_type => $activation,
        @kwargs
    );
}
# Flattens the state_shape array refs of every cell in $cells into a
# single array ref, preserving cell order.
method _cells_state_shape($cells)
{
    my @cell_list = @{ $cells };
    my @shapes;
    for my $cell (@cell_list)
    {
        push @shapes, @{ $cell->state_shape };
    }
    return \@shapes;
}
# Flattens the state_info array refs of every cell in $cells into a
# single array ref, preserving cell order.
method _cells_state_info($cells)
{
    my @cell_list = @{ $cells };
    my @infos;
    for my $cell (@cell_list)
    {
        push @infos, @{ $cell->state_info };
    }
    return \@infos;
}
# Calls begin_state(@kwargs) on every cell in $cells and flattens the
# returned array refs into a single array ref, preserving cell order.
method _cells_begin_state($cells, @kwargs)
{
    my @cell_list = @{ $cells };
    my @initial_states;
    for my $cell (@cell_list)
    {
        push @initial_states, @{ $cell->begin_state(@kwargs) };
    }
    return \@initial_states;
}
# Threads $args through unpack_weights of each cell in turn: the value
# returned by one cell is handed to the next. Returns the final value.
method _cells_unpack_weights($cells, $args)
{
    for my $cell (@{ $cells })
    {
        $args = $cell->unpack_weights($args);
    }
    return $args;
}
# Threads $args through pack_weights of each cell in turn: the value
# returned by one cell is handed to the next. Returns the final value.
method _cells_pack_weights($cells, $args)
{
    for my $cell (@{ $cells })
    {
        $args = $cell->pack_weights($args);
    }
    return $args;
}
package AI::MXNet::RNN::Cell;
use Mouse;
extends 'AI::MXNet::RNN::Cell::Base';
=head1 NAME
AI::MXNet::RNN::Cell
=cut
=head1 DESCRIPTION
Simple recurrent neural network cell
Parameters
----------
num_hidden : int
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
# Packs the per-layer/per-gate weight arrays found in $args into one flat
# parameter array, stored under the name of $self->_parameter.
#
# Works on a shallow copy of $args: every entry whose name is produced by
# _slice_weights is removed from the copy and its contents are written into
# the matching slice of the flat array. Returns a reference to the copy.
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %packed         = %{ $args };
    my $num_directions = @{ $self->_directions };
    my $num_gates      = $self->_num_gates;
    my ($first_gate)   = @{ $self->_gate_names };
    my $hidden         = $self->_num_hidden;

    # The first layer's input-to-hidden weight supplies the input width,
    # the context and the dtype of the flat array.
    my $first_weight_name = sprintf('%sl0_i2h%s_weight', $self->_prefix, $first_gate);
    my $first_weight      = $packed{ $first_weight_name };
    my $num_input         = $first_weight->shape->[1];

    # Element count: first layer term plus the term for the remaining layers.
    my $first_layer_size  = ($num_input + $hidden + 2) * $hidden * $num_gates * $num_directions;
    my $other_layers_size = ($self->_num_layers - 1) * $num_gates * $hidden
                          * ($hidden + $num_directions * $hidden + 2) * $num_directions;
    my $total = $first_layer_size + $other_layers_size;

    my $flat = AI::MXNet::NDArray->zeros(
        [$total],
        ctx   => $first_weight->context,
        dtype => $first_weight->dtype
    );
    my %slices = $self->_slice_weights($flat, $num_input, $hidden);
    for my $name (keys %slices)
    {
        my $slice = $slices{ $name };
        # NOTE(review): '.=' is presumably overloaded by AI::MXNet::NDArray
        # to copy the right-hand side into the slice in place -- confirm.
        $slice .= delete $packed{ $name };
    }
    $packed{ $self->_parameter->name } = $flat;
    return \%packed;
}
# Single-step invocation is not supported here: this always dies with a
# stack trace that points the caller at unroll instead.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $message = "AI::MXNet::RNN::FusedCell cannot be stepped. Please use unroll";
    confess($message);
}
method unroll(
Int $length,
Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
Str :$input_prefix='',
Str :$layout='NTC',
Maybe[Bool] :$merge_outputs=
)
{
$self->reset;
my $axis = index($layout, 'T');
$inputs //= AI::MXNet::Symbol->Variable("${input_prefix}data");
if(blessed($inputs))
{
assert(
(@{ $inputs->list_outputs() } == 1),
"unroll doesn't allow grouped symbol as input. Please "
."convert to list first or let unroll handle slicing"
);
if($axis == 1)
{
AI::MXNet::Logging->warning(
"NTC layout detected. Consider using "
."TNC for RNN::FusedCell for faster speed"
);
$inputs = AI::MXNet::Symbol->SwapAxis($inputs, dim1 => 0, dim2 => 1);
}
else
{
assert($axis == 0, "Unsupported layout $layout");
}
}
else
{
assert(@$inputs == $length);
$inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => 0) } @{ $inputs }];
$inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => 0);
}
$begin_state //= $self->begin_state;
my $states = $begin_state;
my @states = @{ $states };
my %states;
if($self->_mode eq 'lstm')
{
%states = (state => $states[0], state_cell => $states[1]);
}
else
{
%states = (state => $states[0]);
}
my $rnn = AI::MXNet::Symbol->RNN(
data => $inputs,
parameters => $self->_parameter,
state_size => $self->_num_hidden,
num_layers => $self->_num_layers,
bidirectional => $self->_bidirectional,
p => $self->_dropout,
state_outputs => $self->_get_next_state,
mode => $self->_mode,
name => $self->_prefix.'rnn',
%states
);
my $outputs;
my %attr = (__layout__ => 'LNC');
if(not $self->_get_next_state)
{
($outputs, $states) = ($rnn, []);
}
elsif($self->_mode eq 'lstm')
{
my @rnn = @{ $rnn };
$rnn[1]->_set_attr(%attr);
$rnn[2]->_set_attr(%attr);
($outputs, $states) = ($rnn[0], [$rnn[1], $rnn[2]]);
}
else
{
my @rnn = @{ $rnn };
$rnn[1]->_set_attr(%attr);
($outputs, $states) = ($rnn[0], [$rnn[1]]);
}
if(defined $merge_outputs and not $merge_outputs)
{
AI::MXNet::Logging->warning(
"Call RNN::FusedCell->unroll with merge_outputs=1 "
."for faster speed"
);
$outputs = [@ {
AI::MXNet::Symbol->SliceChannel(
$outputs,
axis => 0,
num_outputs => $length,
squeeze_axis => 1
)
}];
}
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
Str :$input_prefix='',
Str :$layout='NTC',
Maybe[Bool] :$merge_outputs=
)
{
$self->reset;
$self->base_cell->_modified(0);
my ($outputs, $states) = $self->base_cell->unroll($length, inputs=>$inputs, begin_state=>$begin_state,
layout=>$layout, merge_outputs=>$merge_outputs);
$self->base_cell->_modified(1);
$merge_outputs //= (blessed($outputs) and $outputs->isa('AI::MXNet::Symbol'));
($inputs) = _normalize_sequence($length, $inputs, $layout, $merge_outputs);
if($merge_outputs)
{
$outputs = AI::MXNet::Symbol->elemwise_add($outputs, $inputs, name => $outputs->name . "_plus_residual");
}
else
{
my @temp;
zip(sub {
my ($output_sym, $input_sym) = @_;
push @temp, AI::MXNet::Symbol->elemwise_add($output_sym, $input_sym,
name=>$output_sym->name."_plus_residual");
}, [@{ $outputs }], [@{ $inputs }]);
$outputs = \@temp;
}
return ($outputs, $states);
}
# Normalizes unroll inputs to either a list of per-step symbols or a single
# merged symbol, and returns ($inputs, $axis), where $axis is the index of
# 'T' in $layout.
#
# Parameters
#   $length    : expected number of time steps; for list inputs the element
#                count is checked against it unless it is undefined.
#   $inputs    : a single Symbol or an array ref of Symbols; must be defined.
#   $layout    : layout string whose 'T' position is the target time axis.
#   $merge     : false -> a single Symbol is split into $length symbols;
#                true  -> a list is expanded and concatenated into one Symbol.
#   $in_layout : optional layout of the incoming data; defaults to $layout.
#
# Dies (via assert) when $inputs is undefined, when a grouped symbol is to
# be split, or when the list length does not match $length.
func _normalize_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert((defined $inputs),
        "unroll(inputs=>undef) has been deprecated. ".
        "Please create input variables outside unroll."
    );
    my $axis    = index($layout, 'T');
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    if(blessed($inputs))
    {
        if(not $merge)
        {
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please "
                ."convert to list first or let unroll handle splitting"
            );
            # Split along the time axis of the *incoming* layout.
            $inputs = [ @{ AI::MXNet::Symbol->split(
                $inputs,
                axis         => $in_axis,
                num_outputs  => $length,
                squeeze_axis => 1
            ) }];
        }
    }
    else
    {
        assert(not defined $length or @$inputs == $length);
        if($merge)
        {
            $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @{ $inputs }];
            $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => $axis);
            # The merged symbol is already laid out along the target axis.
            $in_axis = $axis;
        }
    }
    if(blessed($inputs) and $axis != $in_axis)
    {
        # The swap operator names its two axes dim1 and dim2 (the same keys
        # the SwapAxis call elsewhere in this file uses); the previous
        # dim0/dim1 pair did not match the operator's parameters.
        $inputs = AI::MXNet::Symbol->swapaxes($inputs, dim1 => $axis, dim2 => $in_axis);
    }
    return ($inputs, $axis);
}
1;
( run in 1.463 second using v1.01-cache-2.11-cpan-97f6503c9c8 )