AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Symbol.pm  view on Meta::CPAN

        push @provided_arg_shape_idx, scalar(@provided_arg_shape_data);
    }
    $num_provided_arg_types = @provided_arg_type_names;

    my $provided_req_type_list_len = 0;
    my @provided_grad_req_types;
    my @provided_grad_req_names;
    if(defined $grad_req)
    {
        if(not ref $grad_req)
        {
            push @provided_grad_req_types, $grad_req;
        }
        elsif(ref $grad_req eq 'ARRAY')
        {
            assert((@{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty list');
            @provided_grad_req_types = @{ $grad_req };
            $provided_req_type_list_len = @provided_grad_req_types;
        }
        elsif(ref $grad_req eq 'HASH')
        {
            assert((keys %{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty hash');
            while(my ($k, $v) = each %{ $grad_req })
            {
                push @provided_grad_req_names, $k;
                push @provided_grad_req_types, $v;
            }
            $provided_req_type_list_len = @provided_grad_req_types;
        }
    }
    my $num_ctx_map_keys = 0;
    my @ctx_map_keys;
    my @ctx_map_dev_types;
    my @ctx_map_dev_ids;
    if(defined $group2ctx)
    {
        while(my ($k, $v) = each %{ $group2ctx })
        {
            push @ctx_map_keys, $k;
            push @ctx_map_dev_types, $v->device_type_id;
            push @ctx_map_dev_ids, $v->device_id;
        }
        $num_ctx_map_keys = @ctx_map_keys;
    }

    my @shared_arg_name_list;
    if(defined $shared_arg_names)
    {
        @shared_arg_name_list = @{ $shared_arg_names };
    }
    my %shared_data;
    if(defined $shared_buffer)
    {
        while(my ($k, $v) = each %{ $shared_buffer })
        {
            $shared_data{$k} = $v->handle;
        }
    }
    my $shared_exec_handle = defined $shared_exec ? $shared_exec->handle : undef;
    my (
        $updated_shared_data,
        $in_arg_handles,
        $arg_grad_handles,
        $aux_state_handles,
        $exe_handle
    );
    eval {
        ($updated_shared_data, $in_arg_handles, $arg_grad_handles, $aux_state_handles, $exe_handle)
            =
        check_call(
            AI::MXNetCAPI::ExecutorSimpleBind(
                $self->handle,
                $ctx->device_type_id,
                $ctx->device_id,
                $num_ctx_map_keys,
                \@ctx_map_keys,
                \@ctx_map_dev_types,
                \@ctx_map_dev_ids,
                $provided_req_type_list_len,
                \@provided_grad_req_names,
                \@provided_grad_req_types,
                scalar(@provided_arg_shape_names),
                \@provided_arg_shape_names,
                \@provided_arg_shape_data,
                \@provided_arg_shape_idx,
                $num_provided_arg_types,
                \@provided_arg_type_names,
                \@provided_arg_type_data,
                scalar(@shared_arg_name_list),
                \@shared_arg_name_list,
                defined $shared_buffer ? \%shared_data : undef,
                $shared_exec_handle
            )
        );
    };
    if($@)
    {
        confess(
            "simple_bind failed: Error: $@; Arguments: ".
            Data::Dumper->new(
                [$shapes//{}]
            )->Purity(1)->Deepcopy(1)->Terse(1)->Dump
        );
    }
    if(defined $shared_buffer)
    {
        while(my ($k, $v) = each %{ $updated_shared_data })
        {
            $shared_buffer->{$k} = AI::MXNet::NDArray->new(handle => $v);
        }
    }
    my @arg_arrays  = map { AI::MXNet::NDArray->new(handle => $_) } @{ $in_arg_handles };
    my @grad_arrays = map { defined $_ ? AI::MXNet::NDArray->new(handle => $_) : undef  } @{ $arg_grad_handles };
    my @aux_arrays  = map { AI::MXNet::NDArray->new(handle => $_) } @{ $aux_state_handles };
    my $executor = AI::MXNet::Executor->new(
        handle    => $exe_handle,
        symbol    => $self,
        ctx       => $ctx,
        grad_req  => $grad_req,
        group2ctx => $group2ctx
    );
    $executor->arg_arrays(\@arg_arrays);
    $executor->grad_arrays(\@grad_arrays);
    $executor->aux_arrays(\@aux_arrays);
    return $executor;
}

=head2 bind

    Bind current symbol to get an executor.

    Parameters
    ----------
    :$ctx : AI::MXNet::Context
        The device context on which the generated executor will run.

    :$args : HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]
        Input arguments to the symbol.
            - If type is array ref of NDArray, the positions are in the same order as list_arguments.
            - If type is hash ref of str to NDArray, then it maps the name of arguments
                to the corresponding NDArray.
            - In either case, all the arguments must be provided.

    :$args_grad : Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]]
        When specified, args_grad provides NDArrays to hold
        the resulting gradient values during the backward pass.
            - If type is array ref of NDArray, the positions are in the same order as list_arguments.
            - If type is hash ref of str to NDArray, then it maps the name of arguments
                to the corresponding NDArray.
            - When the type is hash ref of str to NDArray, users only need to provide entries
                for the argument gradients they need.
        Only the specified argument gradient will be calculated.

    :$grad_req : {'write', 'add', 'null'}, or array ref of str or hash ref of str to str, optional
        Specifies how we should update the gradient to the args_grad.
            - 'write' means the gradient is written to the specified args_grad NDArray every time.
            - 'add' means the gradient is added to the specified NDArray every time.
            - 'null' means no action is taken; the gradient may not be calculated.

    :$aux_states : array ref of NDArray, or hash ref of str to NDArray, optional
        Input auxiliary states to the symbol; these only need to be specified when
        list_auxiliary_states is not empty.
            - If type is array ref of NDArray, the positions are in the same order as list_auxiliary_states.
            - If type is hash ref of str to NDArray, then it maps the name of auxiliary_states
                to the corresponding NDArray.
            - In either case, all the auxiliary_states need to be provided.



( run in 2.212 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )