AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

package AI::MXNet::Module::Bucketing;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Base;

=encoding UTF-8

=head1 NAME

AI::MXNet::Module::Bucketing - A module that handles variable-length inputs efficiently through bucketing

=head1 SYNOPSIS

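    use AI::MXNet qw(mx);

    # Hyperparameters used below; the values are only examples.
    my $batch_size = 32;
    my $num_layers = 2;
    my $num_hidden = 200;
    my $num_embed  = 200;
    my $gpus;   # e.g. '0,1'; leave undefined to run on the CPU

    my $buckets = [10, 20, 30, 40, 50, 60];
    my $start_label   = 1;
    my $invalid_label = 0;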

    my ($train_sentences, $vocabulary) = tokenize_text(
        './data/ptb.train.txt', start_label => $start_label,
        invalid_label => $invalid_label
    );
    my ($validation_sentences) = tokenize_text(
        './data/ptb.test.txt', vocab => $vocabulary,
        start_label => $start_label, invalid_label => $invalid_label
    );
    my $data_train  = mx->rnn->BucketSentenceIter(
        $train_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );
    my $data_val    = mx->rnn->BucketSentenceIter(
        $validation_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );

    my $stack = mx->rnn->SequentialRNNCell();
    for my $i (0..$num_layers-1)
    {
        $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
    }

    my $sym_gen = sub {
        my $seq_len = shift;
        my $data  = mx->sym->Variable('data');
        my $label = mx->sym->Variable('softmax_label');
        my $embed = mx->sym->Embedding(
            data => $data, input_dim => scalar(keys %$vocabulary),
            output_dim => $num_embed, name => 'embed'
        );
        $stack->reset;
        my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
        my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
        $pred    = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
        $label   = mx->sym->Reshape($label, shape => [-1]);
        $pred    = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
        return ($pred, ['data'], ['softmax_label']);
    };

    my $contexts;
    if(defined $gpus)
    {
        $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
    }
    else
    {
        $contexts = mx->cpu(0);
    }



( run in 0.480 second using v1.01-cache-2.11-cpan-39bf76dae61 )