AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Module/Bucketing.pm view on Meta::CPAN
package AI::MXNet::Module::Bucketing;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Base;
=encoding UTF-8
=head1 NAME
AI::MXNet::Module::Bucketing
=head1 SYNOPSIS
my $buckets = [10, 20, 30, 40, 50, 60];
my $start_label = 1;
my $invalid_label = 0;
my ($train_sentences, $vocabulary) = tokenize_text(
'./data/ptb.train.txt', start_label => $start_label,
invalid_label => $invalid_label
);
my ($validation_sentences) = tokenize_text(
'./data/ptb.test.txt', vocab => $vocabulary,
start_label => $start_label, invalid_label => $invalid_label
);
my $data_train = mx->rnn->BucketSentenceIter(
$train_sentences, $batch_size, buckets => $buckets,
invalid_label => $invalid_label
);
my $data_val = mx->rnn->BucketSentenceIter(
$validation_sentences, $batch_size, buckets => $buckets,
invalid_label => $invalid_label
);
my $stack = mx->rnn->SequentialRNNCell();
for my $i (0..$num_layers-1)
{
$stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
}
my $sym_gen = sub {
my $seq_len = shift;
my $data = mx->sym->Variable('data');
my $label = mx->sym->Variable('softmax_label');
my $embed = mx->sym->Embedding(
data => $data, input_dim => scalar(keys %$vocabulary),
output_dim => $num_embed, name => 'embed'
);
$stack->reset;
my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
$pred = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
$label = mx->sym->Reshape($label, shape => [-1]);
$pred = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
return ($pred, ['data'], ['softmax_label']);
};
my $contexts;
if(defined $gpus)
{
$contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
$contexts = mx->cpu(0);
}
my $model = mx->mod->BucketingModule(
sym_gen => $sym_gen,
default_bucket_key => $data_train->default_bucket_key,
context => $contexts
);
$model->fit(
$data_train,
eval_data => $data_val,
eval_metric => mx->metric->Perplexity($invalid_label),
kvstore => $kv_store,
optimizer => $optimizer,
optimizer_params => {
learning_rate => $lr,
momentum => $mom,
wd => $wd,
},
initializer => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
num_epoch => $num_epoch,
batch_end_callback => mx->callback->Speedometer($batch_size, $disp_batches),
($chkp_epoch ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);
=head1 DESCRIPTION
Implements the AI::MXNet::Module::Base API and allows multiple
symbols to be used, selected by the C<bucket_key> provided by each
mini-batch of data.
=cut
=head2 new
Parameters
----------
$sym_gen : subref or any perl object that overloads &{} op
A sub that, when called with a bucket key, returns a list of three
elements: ($symbol, $data_names, $label_names).
$default_bucket_key : str or anything else
The key for the default bucket.
$logger : Logger
$context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
Default is cpu(0)
$work_load_list : array ref of Num
Default is undef, indicating uniform workload.
$fixed_param_names: arrayref of str
Default is undef, indicating no network parameters are fixed.
$state_names : arrayref of str
States are similar to data and labels, but are not provided by the data iterator.
Instead, they are initialized to 0 and can be set with set_states().
=cut
extends 'AI::MXNet::Module::Base';

# --- Attributes accepted by the constructor (public names via init_arg) ---

# Generator called with a bucket key; returns ($symbol, $data_names, $label_names).
has _sym_gen => (
    is       => 'ro',
    required => 1,
    init_arg => 'sym_gen',
);

# Key whose symbol is generated at construction time for validation.
has _default_bucket_key => (
    is       => 'rw',
    required => 1,
    init_arg => 'default_bucket_key',
);

# One context or a list of contexts; lazily defaults to the CPU context.
has _context => (
    is       => 'ro',
    init_arg => 'context',
    isa      => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    lazy     => 1,
    default  => sub { AI::MXNet::Context->cpu },
);

has _work_load_list => (
    is       => 'rw',
    isa      => 'ArrayRef[Num]',
    init_arg => 'work_load_list',
);

# --- Internal state (init_arg => undef: cannot be passed to new) ---

has _curr_module     => (is => 'rw', init_arg => undef);
has _curr_bucket_key => (is => 'rw', init_arg => undef);
# NOTE(review): presumably maps bucket key -> bound module; confirm in bind/switch_bucket.
has _buckets         => (is => 'rw', init_arg => undef, default => sub { +{} });

# Optional lists; BUILD replaces a missing value with [].
has _fixed_param_names => (
    is       => 'rw',
    isa      => 'ArrayRef[Str]',
    init_arg => 'fixed_param_names',
);
has _state_names => (
    is       => 'rw',
    isa      => 'ArrayRef[Str]',
    init_arg => 'state_names',
);

has _params_dirty => (is => 'rw', init_arg => undef);
sub BUILD
{
    my ($self, $args) = @_;

    # Optional name lists that were not supplied become empty array refs,
    # so the checks below always receive array refs.
    $self->_fixed_param_names([]) if not defined $args->{fixed_param_names};
    $self->_state_names([])       if not defined $args->{state_names};
    $self->_params_dirty(0);

    # Generate the default bucket's symbol once and validate every group of
    # input names against it.
    my ($default_sym, $data_names, $label_names)
        = $self->_sym_gen->($self->_default_bucket_key);

    $self->_check_input_names($default_sym, $data_names  // [],       "data",        1);
    $self->_check_input_names($default_sym, $label_names // [],       "label",       0);
    $self->_check_input_names($default_sym, $self->_state_names,       "state",       1);
    $self->_check_input_names($default_sym, $self->_fixed_param_names, "fixed_param", 1);
}
method _reset_bind()
{
    # Return to the unbound state: clear the bound flag, drop all cached
    # bucket entries, and forget the currently selected module and key.
    $self->binded(0);
    $self->_buckets(+{});
    $self->_curr_module(undef);
    $self->_curr_bucket_key(undef);
}
method data_names()
{
    # When bound, the module for the current bucket holds the answer.
    return $self->_curr_module->data_names if $self->binded;

    # Otherwise, generate the default bucket and take the second element of
    # ($symbol, $data_names, $label_names).
    return ($self->_sym_gen->($self->_default_bucket_key))[1];
}
method output_names()
{
    # Bound: delegate to the module for the current bucket.
    # (Previously called the misspelled method 'ouput_names', which died.)
    return $self->_curr_module->output_names if $self->binded;

    # Unbound: build the default bucket's symbol and list its outputs.
    # (Previously called the misspelled method 'list_ouputs', which died.)
    my ($symbol) = $self->_sym_gen->($self->_default_bucket_key);
    return $symbol->list_outputs;
}
method data_shapes()
{
    # Only meaningful after bind(); assert that the module is bound.
    assert($self->binded);
    my $current = $self->_curr_module;
    return $current->data_shapes;
}
method label_shapes()
{
    # Only meaningful after bind(); assert that the module is bound.
    assert($self->binded);
    my $current = $self->_curr_module;
    return $current->label_shapes;
}
method output_shapes()
{
assert($self->binded);
( run in 0.604 second using v1.01-cache-2.11-cpan-39bf76dae61 )