AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/RNN/IO.pm view on Meta::CPAN
$ndiscard += 1;
next;
}
my $buff = AI::MXNet::NDArray->full(
[$self->buckets->[$buck]],
$self->invalid_label,
dtype => $self->dtype
)->aspdl;
$buff->slice([0, @{ $self->sentences->[$i] }-1]) .= pdl($self->sentences->[$i]);
push @{ $self->data->[$buck] }, $buff;
}
$self->data([map { pdl(PDL::Type->new(DTYPE_MX_TO_PDL->{$self->dtype}), $_) } @{$self->data}]);
AI::MXNet::Logging->warning("discarded $ndiscard sentences longer than the largest bucket.")
if $ndiscard;
$self->nddata([]);
$self->ndlabel([]);
$self->major_axis(index($self->layout, 'N'));
$self->default_bucket_key(max(@{ $self->buckets }));
my $shape;
if($self->major_axis == 0)
{
$shape = [$self->batch_size, $self->default_bucket_key];
}
elsif($self->major_axis == 1)
{
$shape = [$self->default_bucket_key, $self->batch_size];
}
else
{
confess("Invalid layout ${\ $self->layout }: Must by NT (batch major) or TN (time major)");
}
$self->provide_data([
AI::MXNet::DataDesc->new(
name => $self->data_name,
shape => $shape,
dtype => $self->dtype,
layout => $self->layout
)
]);
$self->provide_label([
AI::MXNet::DataDesc->new(
name => $self->label_name,
shape => $shape,
dtype => $self->dtype,
layout => $self->layout
)
]);
$self->idx([]);
enumerate(sub {
my ($i, $buck) = @_;
my $buck_len = $buck->shape->at(-1);
for my $j (0..($buck_len - $self->batch_size))
{
if(not $j%$self->batch_size)
{
push @{ $self->idx }, [$i, $j];
}
}
}, $self->data);
$self->curr_idx(0);
$self->reset;
}
# Rewind the iterator and prepare a fresh pass over the data.
#
# Sets the cursor back to 0, shuffles the list of [bucket index, offset]
# batch positions, and rebuilds nddata/ndlabel with one NDArray per bucket.
# Each bucket in $self->data is replaced by the result of pdl_shuffle().
# NOTE(review): pdl_shuffle is defined elsewhere; presumably it permutes
# the sentences within the bucket -- confirm against its definition.
#
# The label for a bucket is the bucket shifted left by one position along
# the first PDL dimension, with the final position set to invalid_label.
method reset()
{
    $self->curr_idx(0);
    @{ $self->idx } = shuffle(@{ $self->idx });

    # Start from empty lists; entries are appended bucket by bucket below.
    $self->nddata([]);
    $self->ndlabel([]);

    my $buckets = $self->data;
    for my $k (0 .. $#{ $buckets })
    {
        # Write the shuffled bucket back so $self->data keeps the new order.
        $buckets->[$k] = pdl_shuffle($buckets->[$k]);
        my $tokens = $buckets->[$k];

        # Same shape/type as the bucket, initially all zeros.
        my $targets = $tokens->zeros;

        # targets[0 .. n-2] = tokens[1 .. n-1] along the first dimension.
        $targets->slice([0, -2], 'X') .= $tokens->slice([1, -1], 'X');

        # Nothing follows the last position, so mark it as invalid.
        $targets->slice([-1, -1], 'X') .= $self->invalid_label;

        my $nd_tokens  = AI::MXNet::NDArray->array($tokens,  dtype => $self->dtype);
        push @{ $self->nddata }, $nd_tokens;

        my $nd_targets = AI::MXNet::NDArray->array($targets, dtype => $self->dtype);
        push @{ $self->ndlabel }, $nd_targets;
    }
}
# Produce the next batch, or undef once every batch position has been used.
#
# Reads the [bucket index, offset] pair at the cursor, advances the cursor,
# and slices batch_size consecutive entries starting at the offset out of
# that bucket's data and label NDArrays. When major_axis is 1 the slices
# are transposed before being returned.
#
# Returns an AI::MXNet::DataBatch carrying the data, the label, the bucket
# key for the chosen bucket, a pad of 0, and descriptors whose shapes are
# taken from the sliced arrays.
method next()
{
    my $cursor = $self->curr_idx;
    if($cursor == @{ $self->idx })
    {
        # All batch positions consumed.
        return undef;
    }

    my ($i, $j) = @{ $self->idx->[$cursor] };
    $self->curr_idx($cursor + 1);

    # Inclusive end index of the batch within the bucket.
    my $last = $j + $self->batch_size - 1;

    my $data  = $self->nddata->[$i]->slice([$j, $last]);
    my $label = $self->ndlabel->[$i]->slice([$j, $last]);
    if($self->major_axis == 1)
    {
        $data  = $data->T;
        $label = $label->T;
    }

    my $data_desc = AI::MXNet::DataDesc->new(
        name   => $self->data_name,
        shape  => $data->shape,
        dtype  => $self->dtype,
        layout => $self->layout
    );
    my $label_desc = AI::MXNet::DataDesc->new(
        name   => $self->label_name,
        shape  => $label->shape,
        dtype  => $self->dtype,
        layout => $self->layout
    );

    return AI::MXNet::DataBatch->new(
        data          => [$data],
        label         => [$label],
        bucket_key    => $self->buckets->[$i],
        pad           => 0,
        provide_data  => [$data_desc],
        provide_label => [$label_desc],
    );
}
1;
( run in 1.359 second using v1.01-cache-2.11-cpan-39bf76dae61 )