AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

    else
    {
        return &{$activation}($inputs, @kwargs);
    }
}

# Collects the state shapes of all cells into a single flat array ref.
#
# $cells : array ref of cell objects; each must answer ->state_shape
#          with an array ref.
# Returns a new array ref holding every cell's shapes, in cell order.
method _cells_state_shape($cells)
{
    my @shapes;
    push @shapes, @{ $_->state_shape } for @{ $cells };
    return \@shapes;
}

# Collects the state info entries of all cells into a single flat array ref.
#
# $cells : array ref of cell objects; each must answer ->state_info
#          with an array ref.
# Returns a new array ref holding every cell's entries, in cell order.
method _cells_state_info($cells)
{
    my @info;
    foreach my $cell (@{ $cells })
    {
        push @info, @{ $cell->state_info };
    }
    return \@info;
}

# Asks every cell for its initial states and flattens the results.
#
# $cells  : array ref of cell objects.
# @kwargs : passed unchanged to each cell's begin_state.
# Returns a new array ref holding all initial states, in cell order.
method _cells_begin_state($cells, @kwargs)
{
    my @states;
    foreach my $cell (@{ $cells })
    {
        push @states, @{ $cell->begin_state(@kwargs) };
    }
    return \@states;
}

# Threads $args through unpack_weights of every cell, in order.
# Each cell receives the result produced by the previous one.
#
# Returns the value handed back by the last cell
# (or $args untouched when $cells is empty).
method _cells_unpack_weights($cells, $args)
{
    foreach my $cell (@{ $cells })
    {
        $args = $cell->unpack_weights($args);
    }
    return $args;
}

# Threads $args through pack_weights of every cell, in order.
# Each cell receives the result produced by the previous one.
#
# Returns the value handed back by the last cell
# (or $args untouched when $cells is empty).
method _cells_pack_weights($cells, $args)
{
    foreach my $cell (@{ $cells })
    {
        $args = $cell->pack_weights($args);
    }
    return $args;
}

package AI::MXNet::RNN::Cell;
use Mouse;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::Cell
=cut

=head1 DESCRIPTION

    Simple recurrent neural network cell

    Parameters
    ----------
    num_hidden : int
        number of units in output symbol
    activation : str or Symbol, default 'tanh'
        type of activation function
    prefix : str, default 'rnn_'
        prefix for name of layers
        (and name of weight if params is undef)
    params : AI::MXNet::RNNParams or undef
        container for weight sharing between cells.
        created if undef.
=cut

# Number of hidden units; required, supplied to the constructor as 'num_hidden'.
has '_num_hidden'  => (is => 'ro', init_arg => 'num_hidden', isa => 'Int', required => 1);
# Optional. When defined, BUILD creates i2h_bias with an
# AI::MXNet::LSTMBias initializer carrying this value.
has 'forget_bias'  => (is => 'ro', isa => 'Num');
# Activation function of the cell; supplied to the constructor
# as 'activation' and defaulting to 'tanh'.
has '_activation'  => (
    is       => 'ro',
    init_arg => 'activation',
    isa      => 'Activation',
    default  => 'tanh'
);
# Replaces the default of the inherited _prefix attribute with 'rnn_'.
has '+_prefix'    => (default => 'rnn_');
# Input-to-hidden and hidden-to-hidden weight/bias handles.
# Not settable through the constructor (init_arg => undef);
# BUILD fills them from the params container.
has [qw/_iW _iB
        _hW _hB/] => (is => 'rw', init_arg => undef);

# Constructor argument pre-processing: a lone positional argument is
# treated as num_hidden, so ->new(100) means ->new(num_hidden => 100).
# Any other argument list is handed on unchanged.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    @args = (num_hidden => $args[0]) if @args == 1;
    return $class->$orig(@args);
};

# Post-construction hook: fetches the four cell variables from the
# params container and stores them in _iW, _iB, _hW and _hB.
# i2h_bias is requested with an AI::MXNet::LSTMBias initializer
# only when forget_bias is defined.
sub BUILD
{
    my ($self) = @_;
    my $params = $self->params;
    $self->_iW($params->get('i2h_weight'));

    # Extra named arguments for i2h_bias; stays empty without a forget_bias.
    my @bias_options;
    if(defined($self->forget_bias))
    {
        @bias_options = (
            init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias)
        );
    }
    $self->_iB($params->get('i2h_bias', @bias_options));

    $self->_hW($params->get('h2h_weight'));
    $self->_hB($params->get('h2h_bias'));
}

# Describes the cell's single state: shape [0, num_hidden] with layout 'NC'.
# Returns an array ref containing one hash ref.
method state_info()
{
    my %state = (
        shape      => [0, $self->_num_hidden],
        __layout__ => 'NC'
    );
    return [\%state];
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden,
        name       => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => @{$states}[0],
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden,
        name       => "${name}h2h"
    );
    my $output = $self->_get_activation(
        $i2h + $h2h,

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $prev_state_h = @{ $states }[0];
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden*3,
        name       => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => $prev_state_h,
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden*3,
        name       => "${name}h2h"
    );
    my ($i2h_r, $i2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $i2h, num_outputs => 3, name => "${name}_i2h_slice"
    ) };
    my ($h2h_r, $h2h_z);
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $h2h, num_outputs => 3, name => "${name}_h2h_slice"
    ) };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid", name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid", name => "${name}_z_act"
    );
    my $next_h_tmp = AI::MXNet::Symbol->Activation(
        $i2h + $reset_gate * $h2h, act_type => "tanh", name => "${name}_h_act"
    );
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * $prev_state_h,
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}

package AI::MXNet::RNN::FusedCell;
use Mouse;
use AI::MXNet::Types;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::FusedCell
=cut

=head1 DESCRIPTION

    Fuses RNN layers across time steps into one kernel.
    Improves speed but is less flexible. Currently only
    supported if using cuDNN on GPU.
=cut

# Constructor arguments (init_arg names in quotes):
#   'num_hidden'     - required Int
#   'num_layers'     - Int, default 1
#   'dropout'        - Num, default 0
#   'get_next_state' - Bool, default 0
#   'bidirectional'  - Bool, default 0; BUILD uses it to pick the
#                      directions list ([l r] when true, [l] otherwise)
has '_num_hidden'      => (is => 'ro', isa => 'Int',  init_arg => 'num_hidden',     required => 1);
has '_num_layers'      => (is => 'ro', isa => 'Int',  init_arg => 'num_layers',     default => 1);
has '_dropout'         => (is => 'ro', isa => 'Num',  init_arg => 'dropout',        default => 0);
has '_get_next_state'  => (is => 'ro', isa => 'Bool', init_arg => 'get_next_state', default => 0);
has '_bidirectional'   => (is => 'ro', isa => 'Bool', init_arg => 'bidirectional',  default => 0);
# Passed on to the AI::MXNet::FusedRNN initializer that BUILD creates.
has 'forget_bias'      => (is => 'ro', isa => 'Num',  default => 1);
# Optional initializer. BUILD substitutes a Xavier initializer when it is
# undefined and wraps anything that is not already an AI::MXNet::FusedRNN.
has 'initializer'      => (is => 'rw', isa => 'Maybe[Initializer]');
# Kind of fused network; one of rnn_relu, rnn_tanh, lstm or gru
# (constructor argument 'mode', default 'lstm').
has '_mode'            => (
    is => 'ro',
    isa => enum([qw/rnn_relu rnn_tanh lstm gru/]),
    init_arg => 'mode',
    default => 'lstm'
);
# Internal state populated by BUILD; not settable through the constructor.
has [qw/_parameter
        _directions/] => (is => 'rw', init_arg => undef);

# Constructor argument pre-processing: a lone positional argument is
# interpreted as num_hidden; every other argument list passes through as is.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    if(@args == 1)
    {
        return $class->$orig(num_hidden => $args[0]);
    }
    return $class->$orig(@args);
};

# Post-construction hook for the fused cell:
#   1. falls back to "<mode>_" when the prefix is empty/false;
#   2. supplies a Xavier initializer when none was given;
#   3. wraps the initializer in AI::MXNet::FusedRNN unless it already is one;
#   4. fetches the 'parameters' variable with that initializer;
#   5. records the direction tags ([l r] when bidirectional, otherwise [l]).
sub BUILD
{
    my ($self) = @_;

    $self->_prefix($self->_mode.'_') unless $self->_prefix;

    unless(defined $self->initializer)
    {
        my $xavier = AI::MXNet::Xavier->new(
            factor_type => 'in',
            magnitude   => 2.34
        );
        $self->initializer($xavier);
    }

    unless($self->initializer->isa('AI::MXNet::FusedRNN'))
    {
        my @fused_args = (
            init           => $self->initializer,
            num_hidden     => $self->_num_hidden,
            num_layers     => $self->_num_layers,
            mode           => $self->_mode,
            bidirectional  => $self->_bidirectional,
            forget_bias    => $self->forget_bias
        );
        $self->initializer(AI::MXNet::FusedRNN->new(@fused_args));
    }

    my $parameter = $self->params->get('parameters', init => $self->initializer);
    $self->_parameter($parameter);

    my @directions = $self->_bidirectional ? qw/l r/ : ('l');
    $self->_directions(\@directions);
}


method state_info()
{
    my $b = @{ $self->_directions };
    my $n = $self->_mode eq 'lstm' ? 2 : 1;

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

    }
    return ($inputs, [map { @$_} @next_states]);
}

# Unrolls every cell of the stack in turn, feeding the outputs of one
# cell to the next as its inputs.
#
# Parameters
# ----------
# $length        : number of steps, passed to each cell's unroll.
# :$inputs       : inputs for the first cell.
# :$begin_state  : flat list of initial states for all cells;
#                  defaults to $self->begin_state.
# :$input_prefix : passed to each cell's unroll.
# :$layout       : passed to each cell's unroll.
# :$merge_outputs: applied to the last cell only; all earlier cells
#                  are unrolled with merge_outputs => undef.
#
# Returns (outputs of the last cell, array ref with the flattened
# states returned by all cells, in cell order).
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    my $last_index = @{ $self->_cells } - 1;
    $begin_state //= $self->begin_state;
    my $offset = 0;
    my @next_states;
    enumerate(sub {
        my ($index, $cell) = @_;
        # Each cell consumes as many begin states as it declares in state_info.
        my $state_count = @{ $cell->state_info };
        my @cell_states = @{ $begin_state }[$offset .. $offset + $state_count - 1];
        $offset += $state_count;
        my $merge = ($index < $last_index) ? undef : $merge_outputs;
        ($inputs, my $out_states) = $cell->unroll(
            $length,
            inputs        => $inputs,
            input_prefix  => $input_prefix,
            begin_state   => \@cell_states,
            layout        => $layout,
            merge_outputs => $merge
        );
        push @next_states, $out_states;
    }, $self->_cells);
    my @flattened = map { @{ $_ } } @next_states;
    return ($inputs, \@flattened);
}

package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell

    Parameters
    ----------
    l_cell : AI::MXNet::RNN::Cell::Base
        cell for forward unrolling
    r_cell : AI::MXNet::RNN::Cell::Base
        cell for backward unrolling
    output_prefix : str, default 'bi_'
        prefix for name of output
=cut

# Cell used for forward unrolling (required).
has 'l_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
# Cell used for backward unrolling (required).
has 'r_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
# Prefix for the name of the output; constructor argument 'output_prefix', default 'bi_'.
has '_output_prefix' => (is => 'ro', init_arg => 'output_prefix', isa => 'Str', default => 'bi_');
# Internal state set by BUILD (not settable through the constructor):
# whether 'params' was passed explicitly, and the [l_cell, r_cell] pair.
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

# Constructor argument pre-processing: when the first two arguments are
# both blessed objects they are taken as l_cell and r_cell, followed by
# ordinary named arguments. Otherwise the list passes through unchanged.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    my $positional_cells = (@args >= 2 and blessed $args[0] and blessed $args[1]);
    return $class->$orig(@args) unless $positional_cells;
    my ($l_cell, $r_cell, @named) = @args;
    return $class->$orig(
        l_cell => $l_cell,
        r_cell => $r_cell,
        @named
    );
};

# Post-construction hook for the bidirectional cell.
#
# When 'params' was given explicitly to the constructor, both child cells
# must still own their params (asserted); the container's entries are then
# merged into each child's parameter table. Afterwards both children's
# entries are merged into this cell's table (left first, then right) and
# the pair is stored in _cells.
sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    my @children = ($self->l_cell, $self->r_cell);
    if($self->_override_cell_params)
    {
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
        );
        foreach my $child (@children)
        {
            my $child_table = $child->params->_params;
            %{ $child_table } = (%{ $child_table }, %{ $self->params->_params });
        }
    }
    foreach my $child (@children)
    {
        my $own_table = $self->params->_params;
        %{ $own_table } = (%{ $own_table }, %{ $child->params->_params });
    }
    $self->_cells(\@children);
}

# Unpacks weights by passing $args through both child cells
# via _cells_unpack_weights; returns its result.
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $cells = $self->_cells;
    return $self->_cells_unpack_weights($cells, $args);
}

# Packs weights by passing $args through both child cells
# via _cells_pack_weights; returns its result.
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $cells = $self->_cells;
    return $self->_cells_pack_weights($cells, $args);
}

# Single-step invocation is not supported for a bidirectional cell;
# always dies with a stack trace. Use unroll instead.
method call($inputs, $states)
{
    my $message = "Bidirectional cannot be stepped. Please use unroll";
    confess($message);
}

# Returns the combined state info of both child cells,
# as produced by _cells_state_info.
method state_info()
{
    my $cells = $self->_cells;
    return $self->_cells_state_info($cells);
}

method begin_state(@kwargs)
{

lib/AI/MXNet/RNN/Cell.pm  view on Meta::CPAN

=head1 NAME

    AI::MXNet::RNN::ConvGRUCell
=cut

=head1 DESCRIPTION

    Convolutional GRU network cell.
=cut

# Suffixes for the three gates of this cell, in order: '_r', '_z', '_o'.
# Returns a fresh array ref on every call.
method _gate_names()
{
    my @suffixes = qw/_r _z _o/;
    return \@suffixes;
}

# Builds the symbols for one step of the convolutional GRU cell.
#
# $inputs : input symbol for this step.
# $states : previous states; only the first element is used.
#
# The i2h and h2h results of _conv_forward are each sliced into three
# parts. The first two parts form the sigmoid reset and update gates;
# the third parts form the candidate through the configured activation.
# The output is (1 - update) * candidate + update * previous state.
#
# Returns ($next_h, [$next_h]).
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);

    my ($i2h_r, $i2h_z, $i2h_o) = @{
        AI::MXNet::Symbol->SliceChannel($i2h, num_outputs => 3, name => "${name}_i2h_slice")
    };
    my ($h2h_r, $h2h_z, $h2h_o) = @{
        AI::MXNet::Symbol->SliceChannel($h2h, num_outputs => 3, name => "${name}_h2h_slice")
    };

    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r,
        act_type => "sigmoid",
        name     => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z,
        act_type => "sigmoid",
        name     => "${name}_z_act"
    );

    # Candidate state: the reset gate scales the hidden contribution.
    my $gated_h2h  = $reset_gate * $h2h_o;
    my $next_h_tmp = $self->_get_activation(
        $i2h_o + $gated_h2h,
        $self->_activation,
        name => "${name}_h_act"
    );

    # Blend the candidate with the previous state through the update gate.
    my $from_candidate = (1 - $update_gate) * $next_h_tmp;
    my $from_previous  = $update_gate * $states->[0];
    my $next_h = AI::MXNet::Symbol->_plus(
        $from_candidate, $from_previous,
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}

package AI::MXNet::RNN::ModifierCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, applies modifications
    to it (e.g. Dropout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
=cut

# The wrapped cell (required). BUILD marks it as modified, and the
# modifier delegates params, state_info and weight (un)packing to it.
has 'base_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);

# Constructor argument pre-processing: with an odd number of arguments
# the first one is taken as the base cell and the remainder as named
# arguments. An even-sized list passes through unchanged.
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    unshift @args, 'base_cell' if @args % 2;
    return $class->$orig(@args);
};

# Post-construction hook: flags the wrapped base cell as modified.
sub BUILD
{
    my ($self) = @_;
    my $wrapped = $self->base_cell;
    $wrapped->_modified(1);
}

# Returns the params container of the wrapped base cell.
# Also clears this cell's _own_params flag on every call.
method params()
{
    $self->_own_params(0);
    my $wrapped = $self->base_cell;
    return $wrapped->params;
}

# Returns the state info of the wrapped base cell.
method state_info()
{
    my $wrapped = $self->base_cell;
    return $wrapped->state_info;
}

# Creates the initial states by delegating to the wrapped base cell.
#
# Parameters
# ----------
# :$init_sym : CodeRef, default AI::MXNet::Symbol->can('zeros')
#     handed to the base cell's begin_state as 'func'.
# @kwargs
#     passed unchanged to the base cell's begin_state.
#
# Returns whatever the base cell's begin_state returns.
#
# Asserts that this cell has not itself been marked as modified.
# The base cell's _modified flag is lowered only for the duration of the
# delegated call. It is raised again even when that call dies, so a caught
# exception cannot leave the base cell callable directly; the original
# error is then re-thrown.
method begin_state(CodeRef :$init_sym=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my $base_cell = $self->base_cell;
    $base_cell->_modified(0);
    my $begin_state;
    my $ok = eval {
        $begin_state = $base_cell->begin_state(func => $init_sym, @kwargs);
        1;
    };
    # Capture the error before any further call can clobber $@.
    my $error = $@;
    $base_cell->_modified(1);
    die($error || "begin_state of the base cell failed\n") if not $ok;
    return $begin_state;
}

# Delegates weight unpacking to the wrapped base cell and returns its result.
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $wrapped = $self->base_cell;
    return $wrapped->unpack_weights($args);
}

# Delegates weight packing to the wrapped base cell and returns its result.
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my $wrapped = $self->base_cell;
    return $wrapped->pack_weights($args);
}

# Placeholder for the single-step computation; this base class always
# dies with "Not Implemented" and a stack trace.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $message = "Not Implemented";
    confess($message);
}

package AI::MXNet::RNN::DropoutCell;
use Mouse;
extends 'AI::MXNet::RNN::ModifierCell';



( run in 1.932 second using v1.01-cache-2.11-cpan-39bf76dae61 )