AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/KVStore.pm  view on Meta::CPAN

        [[ 2.  2.  2.]
        [ 2.  2.  2.]]

    >>> # pull into multiple devices
    >>> $b = [map { mx->nd->ones($shape, ctx => $_) } @$gpus]
    >>> $kv->pull(3, out=>$b)
    >>> print $b->[1]->aspdl()
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]

    >>> # pull a list of key-value pairs.
    >>> # On single device
    >>> $keys = [5, 7, 9]
    >>> $b = [map { mx->nd->zeros($shape) } 0..@$keys-1]
    >>> $kv->pull($keys, out=>$b)
    >>> print $b->[1]->aspdl()
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
    >>> # On multiple devices
    >>> $b = [map { [map { mx->nd->ones($shape, ctx => $_) } @$gpus ] } 0..@$keys-1]
    >>> $kv->pull($keys, out=>$b)
    >>> print $b->[1][1]->aspdl()
        [[ 2.  2.  2.]
        [ 2.  2.  2.]]
=cut

method pull(
    Str|ArrayRef[Str] $key,
    AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] :$out,
    Int :$priority=0
)
{
    # _key_value turns the key(s) and the output array(s) into two parallel
    # array refs: one of keys, one of the matching NDArray handles.
    my ($flat_keys, $out_handles) = _key_value($key, $out);
    my $num_keys = @{ $flat_keys };

    # check_call is handed the result of the C API call; its own return
    # value becomes the return value of this method.
    return check_call(
        AI::MXNetCAPI::KVStorePullEx(
            $self->handle, $num_keys, $flat_keys, $out_handles, $priority
        )
    );
}

=head2  set_optimizer

    Register an optimizer to the store

    If there are multiple machines, this process (should be a worker node)
    will pack this optimizer and send it to all servers. It returns after
    this action is done.

    Parameters
    ----------
    optimizer : Optimizer
        the optimizer
=cut

# Registers $optimizer with this store.
#
# On a worker node of a distributed store the optimizer is serialized
# (Storable::freeze, then base64 without line breaks) and sent to the
# servers as command 0. In every other case an updater built from the
# optimizer is installed locally via _set_updater.
method set_optimizer(AI::MXNet::Optimizer $optimizer)
{
    my $is_worker = check_call(AI::MXNetCAPI::KVStoreIsWorkerNode());
    # Distributed store types are named with a suffix (e.g. 'dist_sync',
    # 'dist_async', 'dist_device_sync'), so test for the 'dist' substring;
    # a plain equality test would send those stores down the local branch.
    if($self->type =~ /dist/ and $is_worker)
    {
        my $optim_str = MIME::Base64::encode_base64(Storable::freeze($optimizer), "");
        $self->_send_command_to_servers(0, $optim_str);
    }
    else
    {
        $self->_updater(AI::MXNet::Optimizer->get_updater($optimizer));
        # The wrapper reads $self->_updater at call time rather than
        # capturing the current value.
        $self->_set_updater(sub { $self->_updater->(@_) });
    }
}

=head2  type

    Get the type of this kvstore

    Returns
    -------
    type : str
        the string type
=cut

method type()
{
    # Scalar assignment keeps check_call in scalar context, as the
    # original scalar(...) wrapper did.
    my $store_type = check_call(AI::MXNetCAPI::KVStoreGetType($self->handle));
    return $store_type;
}

=head2  rank

    Get the rank of this worker node

    Returns
    -------
    rank : int
        The rank of this node, which is in [0, num_workers())
=cut

method rank()
{
    # Scalar assignment keeps check_call in scalar context.
    my $node_rank = check_call(AI::MXNetCAPI::KVStoreGetRank($self->handle));
    return $node_rank;
}

=head2  num_workers

    Get the number of worker nodes

    Returns
    -------
    size : int
        The number of worker nodes
=cut

method num_workers()
{
    # Scalar assignment keeps check_call in scalar context.
    my $group_size = check_call(AI::MXNetCAPI::KVStoreGetGroupSize($self->handle));
    return $group_size;
}

=head2 save_optimizer_states

    Save optimizer (updater) state to file

    Parameters
    ----------
    fname : str

lib/AI/MXNet/KVStore.pm  view on Meta::CPAN

    multi-machines.

    Parameters
    ----------
    updater : function
        the updater function

    Examples
    --------
    >>> my $update = sub { my ($key, $input, $stored) = @_;
    ...     print "update on key: $key\n";
    ...     $stored += $input * 2; };
    >>> $kv->_set_updater($update)
    >>> $kv->pull(3, out=>$a)
    >>> print $a->aspdl()
        [[ 4.  4.  4.]
        [ 4.  4.  4.]]
    >>> $kv->push(3, mx->nd->ones($shape))
        update on key: 3
    >>> $kv->pull(3, out=>$a)
    >>> print $a->aspdl()
        [[ 6.  6.  6.]
        [ 6.  6.  6.]]
=cut

method _set_updater(CodeRef $updater_func)
{
    # Adapter passed to the C API: it receives a key index and two raw
    # handles, wraps each handle in an AI::MXNet::NDArray, and forwards
    # all three to the user-supplied updater.
    my $handle_adapter = sub {
        my ($key_index, $recv_handle, $local_handle) = @_;
        my $received = AI::MXNet::NDArray->new(handle => $recv_handle);
        my $stored   = AI::MXNet::NDArray->new(handle => $local_handle);
        return $updater_func->($key_index, $received, $stored);
    };

    # Store the adapter on the object, then pass the stored value to the
    # C API.
    $self->_updater_func($handle_adapter);
    return check_call(
        AI::MXNetCAPI::KVStoreSetUpdater($self->handle, $self->_updater_func)
    );
}

=head2 _barrier

    Global barrier between all worker nodes.

    For example, assume there are n machines, we want to let machine 0 first
    init the values, and then pull the inited value to all machines. Before
    pulling, we can place a barrier to guarantee that the initialization is
    finished.
=cut

method _barrier()
{
    # Calls KVStoreBarrier on this store's handle; check_call's result is
    # the method's return value.
    my $kv_handle = $self->handle;
    return check_call(AI::MXNetCAPI::KVStoreBarrier($kv_handle));
}

=head2 _send_command_to_servers

    Send a command to all server nodes, which will make each server node run
    KVStoreServer.controller.
    This function returns after the command has been executed in all server
    nodes.

    Parameters
    ----------
    head : int
        the head of the command
    body : str
        the body of the command
=cut

method _send_command_to_servers(Int $head, Str $body)
{
    # Forwards the command head and body to the C API for this store's
    # handle. The spelling "Commmand" is the name of the external function.
    my $kv_handle = $self->handle;
    return check_call(
        AI::MXNetCAPI::KVStoreSendCommmandToServers($kv_handle, $head, $body)
    );
}

=head2 create

    Create a new KVStore.

    Parameters
    ----------
    name : Str, default 'local'
        The type of KVStore
        - local works for multiple devices on a single machine (single process)
        - dist works for multi-machines (multiple processes)
    Returns
    -------
    kv : KVStore
        The created AI::MXNet::KVStore
=cut

method create(Str $name='local')
{
    # Ask the C API for a store handle of the requested type, then wrap it
    # in an object of this package (not of the invocant's class).
    my $kv_handle = check_call(AI::MXNetCAPI::KVStoreCreate($name));
    my $store     = __PACKAGE__->new(handle => $kv_handle);
    return $store;
}

sub _key_value
{
    my ($keys, $vals) = @_;
    if(not ref $keys)
    {
        if(blessed $vals)
        {
            return ([$keys], [$vals->handle]);
        }
        else
        {
            for my $value (@{ $vals })
            {
                assert(blessed($value) and $value->isa('AI::MXNet::NDArray'));
                return ([($keys)x@$vals], [map { $_->handle } @$vals]);
            }
        }
    }
    else
    {
        assert(not blessed($vals) and @$keys == @$vals);
        my @c_keys;
        my @c_vals;
        zip(sub {
            my ($key, $val) = @_;
            my ($c_key, $c_val) = _key_value($key, $val);
            push @c_keys, @$c_key;
            push @c_vals, @$c_val;



( run in 1.586 second using v1.01-cache-2.11-cpan-d8267643d1d )