AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN/Cell.pm view on Meta::CPAN
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # Abstract hook: the base cell provides no implementation and always
    # dies with a stack trace; concrete cell classes are expected to
    # override this method.
    confess('Not Implemented');
}
method _gate_names()
{
    # The base cell reports a single gate whose name is the empty string.
    return [ q{} ];
}
=head2 params
Returns the parameters of this cell.
=cut
method params()
{
    # Clear the _own_params flag first, then hand back the stored
    # parameter container (order preserved deliberately).
    $self->_own_params(0);
    $self->_params;
}
=head2 state_shape
Returns an array ref holding the shape of each state (the 'shape' field of each state_info entry), in order.
=cut
method state_shape()
{
    # Collect the 'shape' field of every state_info entry, in order.
    my @shapes;
    for my $entry (@{ $self->state_info })
    {
        push @shapes, $entry->{shape};
    }
    return \@shapes;
}
=head2 state_info
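Shape and layout information of the states: an array ref with one entry (a hash ref, or undef) per state. Must be implemented by subclasses.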
=cut
method state_info()
{
    # Abstract hook: concrete cells describe their states by overriding
    # this; the base implementation always dies with a stack trace.
    confess('Not Implemented');
}
=head2 begin_state
Creates the initial states for this cell.
Parameters
----------
:$func : sub ref, default is AI::MXNet::Symbol->can('zeros')
Function for creating initial state.
Can be AI::MXNet::Symbol->can('zeros'),
AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc.
Use AI::MXNet::Symbol->can('Variable') if you want to directly
feed the input as states.
@kwargs :
    Additional key/value pairs passed to $func, for example
    mean, std, dtype, etc. Entries from state_info take precedence.
Returns
-------
$states : ArrayRef[AI::MXNet::Symbol]
    Starting states for the first RNN step, one per state_info entry.
=cut
method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    # Symbol->Variable receives the state name positionally; every other
    # creator function receives it as a "name => ..." pair.
    my $is_variable = ($func eq AI::MXNet::Symbol->can('Variable'));
    my @begin_states;
    for my $state_info (@{ $self->state_info })
    {
        # Bump the counter before naming so each state name is unique.
        $self->_init_counter($self->_init_counter + 1);
        my $state_name = sprintf(
            "%sbegin_state_%d", $self->_prefix, $self->_init_counter
        );
        # state_info entries may be undef; treat those as "no extra info".
        my %extra = %{ $state_info // {} };
        # For Variable, the layout hint is passed nested under 'kwargs'.
        if($is_variable and exists $extra{__layout__})
        {
            $extra{kwargs} = { __layout__ => delete $extra{__layout__} };
        }
        my @name_args = $is_variable ? ($state_name) : (name => $state_name);
        # Per-state info overrides caller-supplied keyword arguments.
        my %merged = (@kwargs, %extra);
        push @begin_states, $func->(
            'AI::MXNet::Symbol',
            @name_args,
            %merged
        );
    }
    return \@begin_states;
}
=head2 unpack_weights
Unpack fused weight matrices into separate
weight matrices
Parameters
----------
$args : HashRef[AI::MXNet::NDArray]
hash ref containing packed weights.
usually the arg params returned by AI::MXNet::Module->get_params()
Returns
-------
$args : HashRef[AI::MXNet::NDArray]
hash ref with weights associated with
this cell, unpacked.
=cut
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
my %args = %{ $args };
my $h = $self->_num_hidden;
( run in 2.440 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )