AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

        kvstore             => $kv_store,
        optimizer           => $optimizer,
        optimizer_params    => {
                                    learning_rate => $lr,
                                    momentum      => $mom,
                                    wd            => $wd,
                            },
        initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
        num_epoch           => $num_epoch,
        batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
        ($chkp_epoch ? (epoch_end_callback  => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
    );

=head1 DESCRIPTION

    Implements the AI::MXNet::Module::Base API and allows a different
    symbol to be used for each mini-batch of data, selected by the
    `bucket_key` that the mini-batch provides.
=cut


=head2 new

    Parameters
    ----------
    $sym_gen : subref or any perl object that overloads &{} op
        A sub that, when called with a bucket key, returns a list of three
        elements: ($symbol, $data_names, $label_names).
    $default_bucket_key : str or anything else
        The key for the default bucket.
    $logger : Logger
    $context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
        Default is cpu(0)
    $work_load_list : array ref of Num
        Default is undef, indicating uniform workload.
    $fixed_param_names: arrayref of str
        Default is undef, indicating no network parameters are fixed.
    $state_names : arrayref of str
        States are similar to data and label, but are not provided by the data iterator.
        Instead they are initialized to 0 and can be set by set_states().
=cut

# Attribute declarations for the bucketing module.  Every attribute is stored
# under a leading-underscore name; where a constructor argument exists it is
# exposed through init_arg without the underscore.
extends 'AI::MXNet::Module::Base';
# Symbol generator (constructor arg 'sym_gen', mandatory).  It is invoked with
# a bucket key and its result is unpacked as ($symbol, $data_names, $label_names).
has '_sym_gen'            => (is => 'ro', init_arg => 'sym_gen', required => 1);
# Key of the default bucket (constructor arg 'default_bucket_key', mandatory).
has '_default_bucket_key' => (is => 'rw', init_arg => 'default_bucket_key', required => 1);
# Device context or array ref of contexts (constructor arg 'context');
# lazily defaults to AI::MXNet::Context->cpu when not supplied.
has '_context'            => (
    is => 'ro', isa => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    lazy => 1, default => sub { AI::MXNet::Context->cpu },
    init_arg => 'context'
);
# Optional array ref of numbers (constructor arg 'work_load_list'); forwarded
# unchanged to each per-bucket AI::MXNet::Module that gets created.
has '_work_load_list'     => (is => 'rw', init_arg => 'work_load_list', isa => 'ArrayRef[Num]');
# The attributes below with init_arg => undef are internal state and are not
# meant to be supplied to the constructor.
# Module of the currently selected bucket.
has '_curr_module'        => (is => 'rw', init_arg => undef);
# Key of the currently selected bucket.
has '_curr_bucket_key'    => (is => 'rw', init_arg => undef);
# Hash ref mapping bucket key => module; starts out empty.
has '_buckets'            => (is => 'rw', init_arg => undef, default => sub { +{} });
# Optional array ref of strings (constructor arg 'fixed_param_names');
# BUILD replaces a missing value with an empty array ref.
has '_fixed_param_names'  => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'fixed_param_names');
# Optional array ref of strings (constructor arg 'state_names');
# BUILD replaces a missing value with an empty array ref.
has '_state_names'        => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'state_names');
# Flag reset to 0 in BUILD and raised to 1 by update().
has '_params_dirty'       => (is => 'rw', init_arg => undef);

sub BUILD
{
    # Post-construction hook: fills in defaults for the optional name lists,
    # clears the dirty flag, and validates the names produced for the default
    # bucket against the generated symbol.
    #   $self      - the freshly constructed object
    #   $ctor_args - hash ref of the arguments given to the constructor
    my ($self, $ctor_args) = @_;

    # Optional name lists fall back to empty array refs when not supplied.
    if(not defined $ctor_args->{fixed_param_names})
    {
        $self->_fixed_param_names([]);
    }
    if(not defined $ctor_args->{state_names})
    {
        $self->_state_names([]);
    }
    $self->_params_dirty(0);

    # Ask the generator for the default bucket's symbol and input names.
    my $generator = $self->_sym_gen;
    my ($symbol, $data_names, $label_names) = $generator->($self->_default_bucket_key);
    my $data_list  = $data_names  // [];
    my $label_list = $label_names // [];

    # The trailing flag is 1 for every kind of name except "label".
    $self->_check_input_names($symbol, $data_list, "data", 1);
    $self->_check_input_names($symbol, $label_list, "label", 0);
    $self->_check_input_names($symbol, $self->_state_names, "state", 1);
    return $self->_check_input_names($symbol, $self->_fixed_param_names, "fixed_param", 1);
}

method _reset_bind()
{
    # Returns the object to its unbound state: clears the binded flag,
    # replaces the bucket table with a fresh empty hash and forgets which
    # bucket/module was current.
    $self->binded(0);
    $self->_buckets(+{});
    $self->_curr_module(undef);
    return $self->_curr_bucket_key(undef);
}

method data_names()
{
    # Names of the data inputs.  Once bound, the answer comes from the module
    # of the current bucket; before that, the symbol generator is invoked with
    # the default bucket key and the second element of its result is returned.
    return $self->_curr_module->data_names if $self->binded;

    my $generator = $self->_sym_gen;
    return ($generator->($self->_default_bucket_key))[1];
}

method output_names()
{
    # Names of the network outputs.
    # When bound, the request is delegated to the module of the current
    # bucket.  Otherwise the symbol generator is invoked with the default
    # bucket key and the outputs are listed from the resulting symbol.
    #
    # Fix: both calls were misspelled ('ouput_names' / 'list_ouputs'), so
    # neither branch resolved to the intended method.
    if($self->binded)
    {
        return $self->_curr_module->output_names;
    }
    else
    {
        # Only the first element (the symbol) of the generator's result is needed.
        my ($symbol) = &{$self->_sym_gen}($self->_default_bucket_key);
        return $symbol->list_outputs;
    }
}

method data_shapes()
{
    # Data shapes as reported by the current bucket's module.
    # Asserts that the module has been bound.
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->data_shapes;
}

method label_shapes()
{
    # Label shapes as reported by the current bucket's module.
    # Asserts that the module has been bound.
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->label_shapes;
}

method output_shapes()
{
    # Output shapes as reported by the current bucket's module.
    # Asserts that the module has been bound.
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->output_shapes;
}

method get_params()

lib/AI/MXNet/Module/Bucketing.pm  view on Meta::CPAN

    assert($self->binded, 'call bind before switching bucket');
    if(not exists $self->_buckets->{ $bucket_key })
    {
        my ($symbol, $data_names, $label_names) = &{$self->_sym_gen}($bucket_key);
        my $module = AI::MXNet::Module->new(
            symbol         => $symbol,
            data_names     => $data_names,
            label_names    => $label_names,
            logger         => $self->logger,
            context        => $self->_context,
            work_load_list => $self->_work_load_list
        );
        $module->bind(
            data_shapes      => $data_shapes,
            label_shapes     => $label_shapes,
            for_training     => $self->_curr_module->for_training,
            inputs_need_grad => $self->_curr_module->inputs_need_grad,
            force_rebind     => 0,
            shared_module    => $self->_buckets->{ $self->_default_bucket_key },
        );
        $self->_buckets->{ $bucket_key } = $module;
    }
    $self->_curr_module($self->_buckets->{ $bucket_key });
    $self->_curr_bucket_key($bucket_key);
}

method init_optimizer(
    Str        :$kvstore='local',
    Optimizer  :$optimizer='sgd',
    HashRef    :$optimizer_params={ learning_rate => 0.01 },
    Bool       :$force_init=0
)
{
    # Initialises the optimizer on the current bucket's module and lets every
    # other bucket module borrow it.
    #   kvstore          - forwarded to the current module (default 'local')
    #   optimizer        - forwarded to the current module (default 'sgd')
    #   optimizer_params - forwarded to the current module
    #   force_init       - when true, re-initialise even if already done
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);

    # Nothing to do when already initialised, unless the caller insists.
    my $already_done = ($self->optimizer_initialized and not $force_init);
    if($already_done)
    {
        $self->logger->warning('optimizer already initialized, ignoring.');
        return;
    }

    my $current = $self->_curr_module;
    $current->init_optimizer(
        kvstore           => $kvstore,
        optimizer         => $optimizer,
        optimizer_params  => $optimizer_params,
        force_init        => $force_init
    );

    # Every bucket module other than the current one borrows its optimizer.
    for my $other (grep { $_ ne $current } values %{ $self->_buckets })
    {
        $other->borrow_optimizer($current);
    }
    $self->optimizer_initialized(1);
}

method prepare(AI::MXNet::DataBatch $data_batch)
{
    # Makes sure the bucket needed by $data_batch exists by switching to it,
    # then switches back to the bucket that was current on entry.
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);

    # Remember where we were before touching the batch's bucket.
    my $previous_key = $self->_curr_bucket_key;
    my @batch_bucket = (
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );
    ## perform bind if have not done so yet
    $self->switch_bucket(@batch_bucket);

    # Restore the original bucket.
    return $self->switch_bucket(bucket_key => $previous_key);
}

method forward(
    AI::MXNet::DataBatch  $data_batch,
    Bool                 :$is_train=
)
{
    # Switches to the bucket named by $data_batch and runs the forward pass
    # on that bucket's module, passing is_train through unchanged.
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);

    my @batch_bucket = (
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );
    $self->switch_bucket(@batch_bucket);

    my $active_module = $self->_curr_module;
    return $active_module->forward($data_batch, is_train => $is_train);
}

method backward(Maybe[ArrayRef[AI::MXNet::NDArray]|AI::MXNet::NDArray] $out_grads=)
{
    # Runs the backward pass on the current bucket's module, handing over
    # $out_grads as given.
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);
    my $active_module = $self->_curr_module;
    return $active_module->backward($out_grads);
}

method update()
{
    # Marks the parameters as dirty and asks the current bucket's module to
    # perform its update.
    # Asserts that the module is bound and that both the parameters and the
    # optimizer are initialised.
    assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
    $self->_params_dirty(1);
    my $active_module = $self->_curr_module;
    return $active_module->update;
}

method get_outputs(Bool $merge_multi_context=1)
{
    # Outputs of the current bucket's module; $merge_multi_context
    # (default 1) is passed straight through.
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);
    my $active_module = $self->_curr_module;
    return $active_module->get_outputs($merge_multi_context);
}

method get_input_grads(Bool $merge_multi_context=1)
{
    # Input gradients of the current bucket's module; $merge_multi_context
    # (default 1) is passed straight through.
    # Asserts that the module is bound, its parameters are initialised and
    # inputs_need_grad is set.
    assert($self->binded and $self->params_initialized and $self->inputs_need_grad);
    my $active_module = $self->_curr_module;
    return $active_module->get_input_grads($merge_multi_context);
}

method update_metric(
    AI::MXNet::EvalMetric $eval_metric,
    ArrayRef[AI::MXNet::NDArray] $labels
)
{
    # Forwards the evaluation metric and the labels to the current bucket's
    # module so that it can update the metric.
    # Asserts that the module is bound and its parameters are initialised.
    assert($self->binded and $self->params_initialized);
    my $active_module = $self->_curr_module;
    return $active_module->update_metric($eval_metric, $labels);
}

method symbol()
{
    # Symbol of the current bucket's module.
    # Asserts that the module has been bound.
    assert($self->binded);
    my $active_module = $self->_curr_module;
    return $active_module->symbol;
}

method get_symbol()
{
    assert($self->binded);
    return $self->_buckets->{ $self->_default_bucket_key }->symbol;



( run in 0.849 second using v1.01-cache-2.11-cpan-39bf76dae61 )