AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Executor/Group.pm view on Meta::CPAN
{
@grad_req{ @{ $self->_p->arg_names } } = @{ $self->grad_req };
}
else
{
for my $k (@{ $self->_p->arg_names })
{
if(exists $param_names{ $k })
{
$grad_req{$k} = exists $fixed_param_names{ $k } ? 'null' : 'write';
}
elsif(exists $data_names{ $k })
{
$grad_req{$k} = $self->inputs_need_grad ? 'write' : 'null';
}
else
{
$grad_req{$k} = 'null';
}
}
%grad_req = (%grad_req, %{ $self->grad_req });
}
$self->grad_req(\%grad_req);
if(defined $self->shared_group)
{
$self->_p->shared_data_arrays($self->shared_group->_p->shared_data_arrays);
}
else
{
$self->_p->shared_data_arrays([map { +{} } 0..@{ $self->contexts }-1]);
}
$self->_p->output_layouts([
map {
AI::MXNet::DataDesc->get_batch_axis($self->symbol->slice($_)->attr('__layout__'))
} @{ $self->symbol->list_outputs }
]);
$self->bind_exec($self->data_shapes, $self->label_shapes, $self->shared_group);
}
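=head2 decide_slices

Decide the slices for each context according to the workload.

Parameters
----------
$data_shapes : ArrayRef[AI::MXNet::DataDesc]

Returns
-------
ArrayRef holding the batch axis of each entry of $data_shapes, in the same order.

=cut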
method decide_slices(ArrayRef[AI::MXNet::DataDesc] $data_shapes)
{
    # Nothing can be sliced without at least one data descriptor.
    confess("empty data_shapes array") if @{ $data_shapes } == 0;

    # Batch axis of every descriptor, in the same order as $data_shapes.
    my @batch_axes = map { AI::MXNet::DataDesc->get_batch_axis($_->layout) } @{ $data_shapes };

    # Invoked once per (descriptor, batch axis) pair.
    my $visit = sub {
        my ($data_desc, $batch_axis) = @_;

        # Descriptors reporting -1 as their batch axis are skipped.
        return if $batch_axis == -1;

        my $found = $data_desc->shape->[$batch_axis];
        if(not defined $self->_p->batch_size)
        {
            # First batch size seen: record it and split it by workload.
            $self->_p->batch_size($found);
            $self->_p->slices(
                AI::MXNet::Executor::Group::_split_input_slice(
                    $self->_p->batch_size,
                    $self->workload
                )
            );
            return;
        }

        # Every later descriptor has to agree with the recorded batch size.
        return if $found == $self->_p->batch_size;

        my $shape_str = '('. join(',', @{ $data_desc->shape }) . ')';
        confess(
            "all data must have the same batch size: "
            . sprintf("batch_size = %d, but ", $self->_p->batch_size)
            . sprintf("%s has shape %s", $data_desc->name, $shape_str)
        );
    };
    zip($visit, $data_shapes, \@batch_axes);

    return \@batch_axes;
}
# Collect internal arrays from executors.
method _collect_arrays()
{
# convenient data structures
$self->_p->data_arrays([]);
for my $d (@{ $self->data_shapes })
{
my $name = $d->name;
my @tmp;
for my $i (0..@{ $self->_p->execs }-1)
{
push @tmp, [ $self->_p->slices->[$i], $self->_p->execs->[$i]->arg_dict->{$name} ];
}
push @{ $self->_p->data_arrays }, \@tmp;
}
if(defined $self->label_shapes)
{
$self->_p->label_arrays([]);
for my $l (@{ $self->label_shapes })
{
my $name = $l->name;
my @tmp;
for my $i (0..@{ $self->_p->execs }-1)
{
push @tmp, [ $self->_p->slices->[$i], $self->_p->execs->[$i]->arg_dict->{$name} ];
}
push @{ $self->_p->label_arrays }, \@tmp;
}
}
$self->_p->param_arrays([]);
my %param_names = map { $_ => 1 } @{ $self->param_names };
for my $i (0..@{ $self->_p->arg_names }-1)
{
my $name = $self->_p->arg_names->[$i];
if(exists $param_names{$name})
{
my @tmp;
for my $exec (@{ $self->_p->execs })
{
push @tmp, $exec->arg_arrays->[$i];
}
push @{ $self->_p->param_arrays }, \@tmp;
}
}
$self->_p->state_arrays([]);
for my $i (0..@{ $self->state_names }-1)
{
my $name = $self->state_names->[$i];
my @tmp;
( run in 1.297 second using v1.01-cache-2.11-cpan-39bf76dae61 )