AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Symbol.pm view on Meta::CPAN
push @provided_arg_shape_idx, scalar(@provided_arg_shape_data);
}
$num_provided_arg_types = @provided_arg_type_names;
my $provided_req_type_list_len = 0;
my @provided_grad_req_types;
my @provided_grad_req_names;
if(defined $grad_req)
{
if(not ref $grad_req)
{
push @provided_grad_req_types, $grad_req;
}
elsif(ref $grad_req eq 'ARRAY')
{
assert((@{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty list');
@provided_grad_req_types = @{ $grad_req };
$provided_req_type_list_len = @provided_grad_req_types;
}
elsif(ref $grad_req eq 'HASH')
{
assert((keys %{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty hash');
while(my ($k, $v) = each %{ $grad_req })
{
push @provided_grad_req_names, $k;
push @provided_grad_req_types, $v;
}
$provided_req_type_list_len = @provided_grad_req_types;
}
}
my $num_ctx_map_keys = 0;
my @ctx_map_keys;
my @ctx_map_dev_types;
my @ctx_map_dev_ids;
if(defined $group2ctx)
{
while(my ($k, $v) = each %{ $group2ctx })
{
push @ctx_map_keys, $k;
push @ctx_map_dev_types, $v->device_type_id;
push @ctx_map_dev_ids, $v->device_id;
}
$num_ctx_map_keys = @ctx_map_keys;
}
my @shared_arg_name_list;
if(defined $shared_arg_names)
{
@shared_arg_name_list = @{ $shared_arg_names };
}
my %shared_data;
if(defined $shared_buffer)
{
while(my ($k, $v) = each %{ $shared_buffer })
{
$shared_data{$k} = $v->handle;
}
}
my $shared_exec_handle = defined $shared_exec ? $shared_exec->handle : undef;
my (
$updated_shared_data,
$in_arg_handles,
$arg_grad_handles,
$aux_state_handles,
$exe_handle
);
eval {
($updated_shared_data, $in_arg_handles, $arg_grad_handles, $aux_state_handles, $exe_handle)
=
check_call(
AI::MXNetCAPI::ExecutorSimpleBind(
$self->handle,
$ctx->device_type_id,
$ctx->device_id,
$num_ctx_map_keys,
\@ctx_map_keys,
\@ctx_map_dev_types,
\@ctx_map_dev_ids,
$provided_req_type_list_len,
\@provided_grad_req_names,
\@provided_grad_req_types,
scalar(@provided_arg_shape_names),
\@provided_arg_shape_names,
\@provided_arg_shape_data,
\@provided_arg_shape_idx,
$num_provided_arg_types,
\@provided_arg_type_names,
\@provided_arg_type_data,
scalar(@shared_arg_name_list),
\@shared_arg_name_list,
defined $shared_buffer ? \%shared_data : undef,
$shared_exec_handle
)
);
};
if($@)
{
confess(
"simple_bind failed: Error: $@; Arguments: ".
Data::Dumper->new(
[$shapes//{}]
)->Purity(1)->Deepcopy(1)->Terse(1)->Dump
);
}
if(defined $shared_buffer)
{
while(my ($k, $v) = each %{ $updated_shared_data })
{
$shared_buffer->{$k} = AI::MXNet::NDArray->new(handle => $v);
}
}
my @arg_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $in_arg_handles };
my @grad_arrays = map { defined $_ ? AI::MXNet::NDArray->new(handle => $_) : undef } @{ $arg_grad_handles };
my @aux_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $aux_state_handles };
my $executor = AI::MXNet::Executor->new(
handle => $exe_handle,
symbol => $self,
ctx => $ctx,
grad_req => $grad_req,
group2ctx => $group2ctx
);
$executor->arg_arrays(\@arg_arrays);
$executor->grad_arrays(\@grad_arrays);
$executor->aux_arrays(\@aux_arrays);
return $executor;
}
=head2 bind
Bind the current symbol to create an executor.
Parameters
----------
:$ctx : AI::MXNet::Context
The device context on which the generated executor will run.
:$args : HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]
Input arguments to the symbol.
- If type is array ref of NDArray, the position is in the same order of list_arguments.
- If type is hash ref of str to NDArray, then it maps the name of arguments
to the corresponding NDArray.
- In either case, all the arguments must be provided.
:$args_grad : Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]]
When specified, args_grad provides the NDArrays that hold
the gradient values computed during the backward pass.
- If type is array ref of NDArray, the position is in the same order of list_arguments.
- If type is hash ref of str to NDArray, then it maps the name of arguments
to the corresponding NDArray.
- When the type is hash ref of str to NDArray, users only need to provide entries
for the argument gradients they need.
Only the specified argument gradients will be calculated.
:$grad_req : {'write', 'add', 'null'}, an array ref of str, or a hash ref of str to str, optional
Specifies how we should update the gradient to the args_grad.
- 'write' means the gradient is written to the specified args_grad NDArray each time.
- 'add' means the gradient is added to (accumulated into) the specified NDArray each time.
- 'null' means no action is taken, the gradient may not be calculated.
:$aux_states : array ref of NDArray, or hash ref of str to NDArray, optional
Input auxiliary states to the symbol, only need to specify when
list_auxiliary_states is not empty.
- If type is array ref of NDArray, the position is in the same order as list_auxiliary_states.
- If type is hash ref of str to NDArray, then it maps the name of auxiliary_states
to the corresponding NDArray.
- In either case, all the auxiliary_states need to be provided.
( run in 2.212 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )