AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/KVStore.pm  view on Meta::CPAN

package AI::MXNet::KVStore;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::NDArray;
use AI::MXNet::Optimizer;
use MIME::Base64;
use Storable;
use Mouse;
use AI::MXNet::Function::Parameters;

=head1 NAME

    AI::MXNet::KVStore - Key value store interface of MXNet.

=head1 DESCRIPTION 

    Key-value store interface of MXNet for parameter synchronization over multiple devices.
=cut

# Native KVStore handle; released by DEMOLISH via KVStoreFree.
has 'handle' => (
    is       => 'ro',
    isa      => 'KVStoreHandle',
    required => 1,
);

# Local optimizer updater; save/load_optimizer_states refuse to run without it.
has '_updater' => (
    is  => 'rw',
    isa => 'AI::MXNet::Updater',
);

# Wrapper closure registered with KVStoreSetUpdater; kept on the object,
# presumably so the Perl closure stays referenced while the C side uses it.
has '_updater_func' => (
    is  => 'rw',
    isa => 'CodeRef',
);

# Destructor hook: hand the native KVStore handle back to the C API.
sub DEMOLISH
{
    my ($self) = @_;
    my $kv_handle = $self->handle;
    check_call(AI::MXNetCAPI::KVStoreFree($kv_handle));
}

=head2  init

    Initialize a single or a sequence of key-value pairs into the store.
    For each key, one must init it before push and pull.
    Only worker 0's (rank == 0) data are used.
    This function returns after the data have been initialized successfully.

    Parameters
    ----------
    key : str or an array ref of str
        The keys.
    value : NDArray or an array ref of NDArray objects
        The values.

    Examples
    --------
    >>> # init a single key-value pair
    >>> $shape = [2,3]
    >>> $kv = mx->kv->create('local')
    >>> $kv->init(3, mx->nd->ones($shape)*2)
    >>> $a = mx->nd->zeros($shape)
    >>> $kv->pull(3, out=>$a)
    >>> print $a->aspdl
    [[ 2  2  2]
    [ 2  2  2]]

    >>> # init a list of key-value pairs
    >>> $keys = [5, 7, 9]
    >>> $kv->init($keys, [map { mx->nd->ones($shape) } 0..@$keys-1])
=cut

method init(
    Str|ArrayRef[Str] $key,
    AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $value
)
{
    # _key_value hands back an array ref of keys plus the matching values
    # (presumably normalising the single-key and list forms; see its definition).
    my ($key_list, $value_list) = _key_value($key, $value);
    my $key_count = @{ $key_list };

    # Register every key with the native store in a single C call.
    check_call(
        AI::MXNetCAPI::KVStoreInitEx(
            $self->handle,
            $key_count,
            $key_list,
            $value_list,
        )
    );
}

=head2  push

    Push a single or a sequence of key-value pairs into the store.
    Data consistency:
    1. this function returns after adding an operator to the engine.
    2. push is always called after all previous push and pull on the same
        key are finished.
    3. there is no synchronization between workers. One can use _barrier()
    to sync all workers.

    Parameters
    ----------
    key : str or array ref of str
    value : NDArray or array ref of NDArray or array ref of array refs of NDArray
    priority : int, optional
        The priority of the push operation.
        The higher the priority, the more likely this operation is
        to be executed before other push actions.

    Examples
    --------
    >>> # push a single key-value pair

lib/AI/MXNet/KVStore.pm  view on Meta::CPAN

    Returns
    -------
    rank : int
        The rank of this node, which is in [0, num_workers())
=cut

method rank()
{
    # Assigning to a scalar forces scalar context on check_call,
    # so exactly one value (the rank) is returned.
    my $node_rank = check_call(AI::MXNetCAPI::KVStoreGetRank($self->handle));
    return $node_rank;
}

=head2  num_workers

    Get the number of worker nodes

    Returns
    -------
    size : int
        The number of worker nodes
=cut

method num_workers()
{
    # Scalar assignment gives check_call scalar context: one value, the group size.
    my $group_size = check_call(AI::MXNetCAPI::KVStoreGetGroupSize($self->handle));
    return $group_size;
}

=head2 save_optimizer_states

    Save optimizer (updater) state to file

    Parameters
    ----------
    fname : str
        Path to output states file.
=cut

method save_optimizer_states(Str $fname)
{
    # Saving needs a locally attached updater; per the message below, its
    # absence is taken to mean distributed training.
    confess("Cannot save states for distributed training")
        unless defined $self->_updater;

    # Lexical handle instead of the package-global bareword F (which the
    # load method also used), written unmodified via :raw.
    open(my $fh, '>:raw', $fname)
        or confess("can't open $fname for writing: $!");
    print {$fh} $self->_updater->get_states()
        or confess("can't write to $fname: $!");

    # Buffered write errors (e.g. a full disk) only surface at close time,
    # so the result of close must be checked on a write handle.
    close($fh)
        or confess("can't close $fname after writing: $!");
}

=head2 load_optimizer_states

    Load optimizer (updater) state from file.

    Parameters
    ----------
    fname : str
        Path to input states file.
=cut

method load_optimizer_states(Str $fname)
{
    # Loading needs a locally attached updater to receive the states.
    # (The message previously said "save", a copy-paste error.)
    confess("Cannot load states for distributed training")
        unless defined $self->_updater;

    open(my $fh, '<:raw', $fname)
        or confess("can't open $fname for reading: $!");
    # Slurp the whole file byte-for-byte: an undefined $/ disables line splitting.
    my $data = do { local $/; <$fh> };
    close($fh);

    $self->_updater->set_states($data);
}

=head2 _set_updater

    Set a push updater into the store.

    This function only changes the local store. Use set_optimizer for
    multi-machines.

    Parameters
    ----------
    updater_func : code ref
        The updater, called with the key, the pushed value and the stored value.

    Examples
    --------
    >>> my $update = sub { my ($key, $input, $stored) = @_;
    ...     print "update on key: $key\n";
    ...     $stored += $input * 2; };
    >>> $kv->_set_updater($update)
    >>> $kv->pull(3, out=>$a)
    >>> print $a->aspdl()
    [[ 4.  4.  4.]
    [ 4.  4.  4.]]
    >>> $kv->push(3, mx->nd->ones($shape))
    update on key: 3
    >>> $kv->pull(3, out=>$a)
    >>> print $a->aspdl()
    [[ 6.  6.  6.]
    [ 6.  6.  6.]]
=cut

method _set_updater(CodeRef $updater_func)
{
    # The C store invokes this wrapper with a key and two raw NDArray
    # handles; wrap them as AI::MXNet::NDArray objects before calling
    # the user-supplied code ref.
    my $bridge = sub {
        my ($key_index, $recv_handle, $local_handle) = @_;
        my $recv_array  = AI::MXNet::NDArray->new(handle => $recv_handle);
        my $local_array = AI::MXNet::NDArray->new(handle => $local_handle);
        $updater_func->($key_index, $recv_array, $local_array);
    };

    # Store the wrapper on the object (presumably keeping it referenced for
    # as long as the C side may call it), then register it.
    $self->_updater_func($bridge);
    check_call(
        AI::MXNetCAPI::KVStoreSetUpdater($self->handle, $self->_updater_func)
    );
}

=head2 _barrier

    Global barrier between all worker nodes.

    For example, assume there are n machines, we want to let machine 0 first
    init the values, and then pull the inited value to all machines. Before
    pulling, we can place a barrier to guarantee that the initialization is
    finished.



( run in 1.756 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )