AI-MXNet
view release on metacpan or search on metacpan
examples/cudnn_lstm_bucketing.pl view on Meta::CPAN
}
else
{
$stack = mx->rnn->FusedRNNCell(
$num_hidden, num_layers => $num_layers,
mode=>'lstm', bidirectional => $bidirectional
)->unfuse()
}
my $sym_gen = sub {
my $seq_len = shift;
my $data = mx->sym->Variable('data');
my $label = mx->sym->Variable('softmax_label');
my $embed = mx->sym->Embedding(
data => $data, input_dim => scalar(keys %$vocab),
output_dim => $num_embed, name => 'embed'
);
$stack->reset;
my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden*(1+$bidirectional)]);
$pred = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocab), name => 'pred');
$label = mx->sym->Reshape($label, shape => [-1]);
$pred = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
return ($pred, ['data'], ['softmax_label']);
};
my $contexts;
if($gpus)
{
$contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
$contexts = mx->cpu(0);
}
my ($arg_params, $aux_params);
if($load_epoch)
{
(undef, $arg_params, $aux_params) = mx->rnn->load_rnn_checkpoint(
$stack, $model_prefix, $load_epoch);
}
my $model = mx->mod->BucketingModule(
sym_gen => $sym_gen,
default_bucket_key => $data_val->default_bucket_key,
context => $contexts
);
$model->bind(
data_shapes => $data_val->provide_data,
label_shapes => $data_val->provide_label,
for_training => 0,
force_rebind => 0
);
$model->set_params($arg_params, $aux_params);
my $score = $model->score($data_val,
mx->metric->Perplexity($invalid_label),
batch_end_callback=>mx->callback->Speedometer($batch_size, 5)
);
};
# Heuristic warning: a deep fused (non-stacked) RNN spread across many GPUs
# is better served by --stack-rnn.  The GPU count comes from the
# comma-separated $gpus list.  An unset or empty list (CPU run) counts as
# zero GPUs, so no "uninitialized value" warning is emitted.  The bare block
# keeps the helper array from leaking into file scope.
{
    # Explicit list context: count the GPU ids instead of relying on
    # split's scalar-context behaviour.
    my @warn_gpu_ids = defined $gpus ? split(/,/, $gpus) : ();
    if ($num_layers >= 4 and @warn_gpu_ids >= 4 and not $stack_rnn)
    {
        print("WARNING: stack-rnn is recommended to train complex model on multiple GPUs\n");
    }
}
# Entry point: either evaluate a saved checkpoint ($test) or train ($train).
if($do_test)
{
    # Demonstrates how to load a model trained with CuDNN RNN and predict
    # with non-fused MXNet symbol
    $test->();
}
else
{
    $train->();
}
( run in 1.835 second using v1.01-cache-2.11-cpan-fe3c2283af0 )