AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

package AI::MXNet::Module::Bucketing;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Base;

=encoding UTF-8

=head1 NAME

AI::MXNet::Module::Bucketing

=head1 SYNOPSIS

    my $buckets = [10, 20, 30, 40, 50, 60];
    my $start_label   = 1;
    my $invalid_label = 0;

    my ($train_sentences, $vocabulary) = tokenize_text(
        './data/ptb.train.txt', start_label => $start_label,
        invalid_label => $invalid_label
    );
    my ($validation_sentences) = tokenize_text(
        './data/ptb.test.txt', vocab => $vocabulary,
        start_label => $start_label, invalid_label => $invalid_label
    );
    my $data_train  = mx->rnn->BucketSentenceIter(
        $train_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );
    my $data_val    = mx->rnn->BucketSentenceIter(
        $validation_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );

    my $stack = mx->rnn->SequentialRNNCell();
    for my $i (0..$num_layers-1)
    {
        $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
    }

    my $sym_gen = sub {
        my $seq_len = shift;
        my $data  = mx->sym->Variable('data');
        my $label = mx->sym->Variable('softmax_label');
        my $embed = mx->sym->Embedding(
            data => $data, input_dim => scalar(keys %$vocabulary),
            output_dim => $num_embed, name => 'embed'
        );
        $stack->reset;
        my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
        my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
        $pred    = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
        $label   = mx->sym->Reshape($label, shape => [-1]);
        $pred    = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
        return ($pred, ['data'], ['softmax_label']);
    };

    my $contexts;
    if(defined $gpus)
    {
        $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
    }
    else
    {
        $contexts = mx->cpu(0);
    }

    my $model = mx->mod->BucketingModule(
        sym_gen             => $sym_gen,
        default_bucket_key  => $data_train->default_bucket_key,
        context             => $contexts
    );

    $model->fit(
        $data_train,
        eval_data           => $data_val,
        eval_metric         => mx->metric->Perplexity($invalid_label),
        kvstore             => $kv_store,
        optimizer           => $optimizer,
        optimizer_params    => {
                                    learning_rate => $lr,
                                    momentum      => $mom,
                                    wd            => $wd,
                            },
        initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
        num_epoch           => $num_epoch,
        batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
        ($chkp_epoch ? (epoch_end_callback  => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
    );

=head1 DESCRIPTION

    Implements the AI::MXNet::Module::Base API, and allows multiple
    symbols to be used depending on the `bucket_key` provided by each different
    mini-batch of data.

=cut


=head2 new

    Parameters
    ----------
    $sym_gen : subref or any perl object that overloads &{} op
        A sub that, when called with a bucket key, returns a list of three
        items: ($symbol, $data_names, $label_names).
    $default_bucket_key : str or anything else
        The key for the default bucket.
    $logger : Logger
    $context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
        Default is cpu(0)
    $work_load_list : array ref of Num
        Default is undef, indicating uniform workload.
    $fixed_param_names: arrayref of str
        Default is undef, indicating no network parameters are fixed.
    $state_names : arrayref of str
        States are similar to data and label, but are not provided by the data iterator.
        Instead they are initialized to 0 and can be set by set_states().

=cut

# Inherit the generic module interface from the base class.
extends 'AI::MXNet::Module::Base';

# Symbol generator (constructor argument 'sym_gen', mandatory). It is invoked
# with a bucket key and yields ($symbol, $data_names, $label_names).
has '_sym_gen' => (
    is       => 'ro',
    init_arg => 'sym_gen',
    required => 1,
);

# Key of the default bucket (constructor argument 'default_bucket_key', mandatory).
has '_default_bucket_key' => (
    is       => 'rw',
    init_arg => 'default_bucket_key',
    required => 1,
);

# Device context(s) (constructor argument 'context'): a single context or an
# array ref of contexts. Lazily falls back to AI::MXNet::Context->cpu.
has '_context' => (
    is       => 'ro',
    isa      => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    init_arg => 'context',
    lazy     => 1,
    default  => sub { AI::MXNet::Context->cpu },
);

# Optional list of numbers (constructor argument 'work_load_list').
has '_work_load_list' => (
    is       => 'rw',
    isa      => 'ArrayRef[Num]',
    init_arg => 'work_load_list',
);

# Internal state: none of the following three can be supplied to the
# constructor (init_arg => undef).
has '_curr_module' => (
    is       => 'rw',
    init_arg => undef,
);
has '_curr_bucket_key' => (
    is       => 'rw',
    init_arg => undef,
);
# Starts out as an empty hash ref.
has '_buckets' => (
    is       => 'rw',
    init_arg => undef,
    default  => sub { +{} },
);

# Optional lists of names (constructor arguments 'fixed_param_names' and
# 'state_names'); BUILD replaces a missing value with an empty array ref.
has '_fixed_param_names' => (
    is       => 'rw',
    isa      => 'ArrayRef[Str]',
    init_arg => 'fixed_param_names',
);
has '_state_names' => (
    is       => 'rw',
    isa      => 'ArrayRef[Str]',
    init_arg => 'state_names',
);

# Internal flag, not settable through the constructor; BUILD initialises it to 0.
has '_params_dirty' => (
    is       => 'rw',
    init_arg => undef,
);

# Post-construction hook.
#
# Parameters
# ----------
# $self      : the freshly constructed object
# $ctor_args : hash ref of the arguments that were passed to the constructor
#
# Fills in empty lists for fixed_param_names / state_names when the caller did
# not supply them (or supplied undef), clears the params-dirty flag, then runs
# the symbol generator once for the default bucket key and passes the resulting
# symbol through _check_input_names for the data, label, state and fixed_param
# name lists.
sub BUILD
{
    my ($self, $ctor_args) = @_;

    if(not defined $ctor_args->{fixed_param_names})
    {
        $self->_fixed_param_names([]);
    }
    if(not defined $ctor_args->{state_names})
    {
        $self->_state_names([]);
    }
    $self->_params_dirty(0);

    # The generator is called in list context: ($symbol, $data_names, $label_names).
    my $bucket_key = $self->_default_bucket_key;
    my ($symbol, $data_names, $label_names) = $self->_sym_gen->($bucket_key);

    # A generator may leave either name list undefined; treat that as "no names".
    $data_names  //= [];
    $label_names //= [];

    $self->_check_input_names($symbol, $data_names, "data", 1);
    $self->_check_input_names($symbol, $label_names, "label", 0);
    $self->_check_input_names($symbol, $self->_state_names, "state", 1);
    return $self->_check_input_names($symbol, $self->_fixed_param_names, "fixed_param", 1);
}

# Internal helper: puts the object back into its unbound state by clearing the
# binded flag, emptying the bucket table and forgetting the current module and
# the current bucket key.
method _reset_bind()
{
    $self->binded(0);
    $self->_buckets(+{});
    $self->_curr_module(undef);
    return $self->_curr_bucket_key(undef);
}

# Returns the data names.
#
# When the module is bound the answer comes from the current module; otherwise
# the symbol generator is run for the default bucket key and the second element
# of the list it returns (the data names) is handed back.
method data_names()
{
    return $self->_curr_module->data_names if $self->binded;

    my $sym_gen = $self->_sym_gen;
    return ($sym_gen->($self->_default_bucket_key))[1];
}

# Returns the output names.
#
# When the module is bound the answer is delegated to the current module;
# otherwise the symbol generator is run for the default bucket key and the
# outputs of the generated symbol are listed.
#
# Both delegated calls were previously misspelled ("ouput_names" and
# "list_ouputs"), so this method died with a "Can't locate object method"
# error whenever it was called. The names now match this method's own name and
# the symbol API's list_outputs.
method output_names()
{
    if($self->binded)
    {
        return $self->_curr_module->output_names;
    }
    else
    {
        # Only the first element (the symbol) of the generator's result is needed.
        my ($symbol) = &{$self->_sym_gen}($self->_default_bucket_key);
        return $symbol->list_outputs;
    }
}

# Returns the data shapes reported by the current module.
# Asserts that the module is bound before delegating.
method data_shapes()
{
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->data_shapes;
}

# Returns the label shapes reported by the current module.
# Asserts that the module is bound before delegating.
method label_shapes()
{
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->label_shapes;
}

method output_shapes()
{
    assert($self->binded);



( run in 0.604 second using v1.01-cache-2.11-cpan-39bf76dae61 )