AI-MXNet

 view release on metacpan or  search on metacpan

examples/cudnn_lstm_bucketing.pl  view on Meta::CPAN

#!/usr/bin/perl
use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::Function::Parameters;
use AI::MXNet::Base;
use Getopt::Long qw(HelpMessage);

# Parse command-line options. Each option is bound, via a scalar reference,
# to a lexical declared in place; the value on the right of '=' is its
# default, and options without one stay undef unless given on the command
# line. 'num-epochs' is accepted as an alias of 'num-epoch' so that both
# spellings work. Any parse failure falls through to HelpMessage(1).
GetOptions(
    'test'                   => \(my $do_test                ),
    'num-layers=i'           => \(my $num_layers   = 2       ),
    'num-hidden=i'           => \(my $num_hidden   = 256     ),
    'num-embed=i'            => \(my $num_embed    = 256     ),
    'num-seq=i'              => \(my $seq_size     = 32      ),
    'gpus=s'                 => \(my $gpus                   ),
    'kv-store=s'             => \(my $kv_store     = 'device'),
    'num-epoch|num-epochs=i' => \(my $num_epoch    = 25      ),
    'lr=f'                   => \(my $lr           = 0.01    ),
    'optimizer=s'            => \(my $optimizer    = 'adam'  ),
    'mom=f'                  => \(my $mom          = 0       ),
    'wd=f'                   => \(my $wd           = 0.00001 ),
    'batch-size=i'           => \(my $batch_size   = 32      ),
    'disp-batches=i'         => \(my $disp_batches = 50      ),
    'model-prefix=s'         => \(my $model_prefix = 'lstm_' ),
    'load-epoch=i'           => \(my $load_epoch   = 0       ),
    'stack-rnn'              => \(my $stack_rnn              ),
    'bidirectional=i'        => \(my $bidirectional          ),
    'dropout=f'              => \(my $dropout      = 0       ),
    'help'                   => sub { HelpMessage(0) },
) or HelpMessage(1);

=head1 NAME

    cudnn_lstm_bucketing.pl - Example of training a bucketed LSTM RNN on the Penn Treebank (PTB) data using the high level RNN interface

=head1 SYNOPSIS

    --test           whether to test instead of train (flag, default off)
    --num-layers     number of stacked RNN layers, default=2
    --num-hidden     hidden layer size, default=256
    --num-embed      embedding size, default=256
    --num-seq        sequence size, default=32
    --gpus           list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.
                     Increase batch size when using multiple gpus for best performance.
    --kv-store       key-value store type, default='device'
    --num-epoch      max num of epochs, default=25
    --lr             initial learning rate, default=0.01
    --optimizer      the optimizer type, default='adam'
    --mom            momentum for sgd, default=0.0
    --wd             weight decay for sgd, default=0.00001
    --batch-size     the batch size, default=32
    --disp-batches   show progress for every n batches, default=50
    --model-prefix   prefix for checkpoint files for loading/saving, default='lstm_'
    --load-epoch     load from epoch, default=0
    --stack-rnn      stack rnn to reduce communication overhead (flag, default off)
    --bidirectional  whether to use bidirectional layers (1,0 default 0)
    --dropout        dropout probability (1.0 - keep probability), default 0
    --help           print this help message and exit

=cut

# Coerce the boolean command-line options to strict 1/0 integers.
# The foreach loop aliases $flag to each variable in turn, so the
# assignment writes straight back to $bidirectional and $stack_rnn.
for my $flag ($bidirectional, $stack_rnn) {
    $flag = $flag ? 1 : 0;
}

# Read a space-separated text corpus and encode it into integer sentences.
#
# $fname          path of the corpus; dies if it cannot be opened.
# :$vocab         optional existing vocabulary to encode against (undef lets
#                 encode_sentences build one).
# :$invalid_label passed through to mx->rnn->encode_sentences.
# :$start_label   passed through to mx->rnn->encode_sentences.
#
# Returns ($sentences, $vocab) exactly as produced by encode_sentences.
func tokenize_text($fname, :$vocab=, :$invalid_label=-1, :$start_label=0)
{
    open(my $fh, '<', $fname) or die "Can't open $fname: $!";
    # split / / keeps empty fields (from a leading or doubled space); drop
    # them so they never become '' tokens and no real word is discarded.
    # Lines are not chomped, so the newline stays in the final field.
    my @lines = map {
        my $line = $_;
        [ grep { length } split / /, $line ]
    } <$fh>;
    close($fh);

    my ($sentences, $encoded_vocab) = mx->rnn->encode_sentences(
        \@lines,
        vocab         => $vocab,
        invalid_label => $invalid_label,
        start_label   => $start_label
    );
    return ($sentences, $encoded_vocab);
}

# Bucket lengths handed to BucketSentenceIter (10, 20, ..., 60), and the
# start/invalid label ids handed to both the sentence encoder and the
# iterators in get_data().
my $buckets       = [ map { 10 * $_ } 1 .. 6 ];
my $start_label   = 1;
my $invalid_label = 0;

# Build bucketed training and validation iterators over the PTB corpus.
#
# $layout  data layout string forwarded to BucketSentenceIter.
#
# The training file defines the vocabulary; the test file is then encoded
# against that same vocabulary. Both iterators share $batch_size, $buckets
# and $invalid_label from file scope.
#
# Returns ($data_train, $data_val, $vocabulary).
func get_data($layout)
{
    my @label_opts = (
        start_label   => $start_label,
        invalid_label => $invalid_label,
    );
    my ($train_sentences, $vocabulary) =
        tokenize_text('./data/ptb.train.txt', @label_opts);
    my ($validation_sentences) =
        tokenize_text('./data/ptb.test.txt', vocab => $vocabulary, @label_opts);

    # One constructor for both splits; only the sentences differ.
    my $bucket_iter_for = sub {
        my ($sentences) = @_;
        return mx->rnn->BucketSentenceIter(
            $sentences, $batch_size, buckets => $buckets,
            invalid_label => $invalid_label,
            layout        => $layout
        );
    };

    my $data_train = $bucket_iter_for->($train_sentences);
    my $data_val   = $bucket_iter_for->($validation_sentences);
    return ($data_train, $data_val, $vocabulary);
}

my $train = sub



( run in 1.821 second using v1.01-cache-2.11-cpan-8f98c5d2c55 )