AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/IO.pm  view on Meta::CPAN

            $ndiscard += 1;
            next;
        }
        my $buff = AI::MXNet::NDArray->full(
            [$self->buckets->[$buck]],
            $self->invalid_label,
            dtype => $self->dtype
        )->aspdl;
        $buff->slice([0, @{ $self->sentences->[$i] }-1]) .= pdl($self->sentences->[$i]);
        push @{ $self->data->[$buck] }, $buff;
    }
    $self->data([map { pdl(PDL::Type->new(DTYPE_MX_TO_PDL->{$self->dtype}), $_) } @{$self->data}]);
    AI::MXNet::Logging->warning("discarded $ndiscard sentences longer than the largest bucket.")
        if $ndiscard;
    $self->nddata([]);
    $self->ndlabel([]);
    $self->major_axis(index($self->layout, 'N'));
    $self->default_bucket_key(max(@{ $self->buckets }));
    my $shape;
    if($self->major_axis == 0)
    {
        $shape = [$self->batch_size, $self->default_bucket_key];
    }
    elsif($self->major_axis == 1)
    {
        $shape = [$self->default_bucket_key, $self->batch_size];
    }
    else
    {
        confess("Invalid layout ${\ $self->layout }: Must by NT (batch major) or TN (time major)");
    }
    $self->provide_data([
        AI::MXNet::DataDesc->new(
            name  => $self->data_name,
            shape => $shape,
            dtype => $self->dtype,
            layout => $self->layout
        )
    ]);
    $self->provide_label([
        AI::MXNet::DataDesc->new(
            name  => $self->label_name,
            shape => $shape,
            dtype => $self->dtype,
            layout => $self->layout
        )
    ]);
    $self->idx([]);
    enumerate(sub {
        my ($i, $buck) = @_;
        my $buck_len = $buck->shape->at(-1);
        for my $j (0..($buck_len - $self->batch_size))
        {
            if(not $j%$self->batch_size)
            {
                push @{ $self->idx }, [$i, $j];
            }
        }
    }, $self->data);
    $self->curr_idx(0);
    $self->reset;
}

# Begin a new pass over the data.
#
# - Rewinds the batch cursor (curr_idx) to 0.
# - Shuffles the (bucket, offset) schedule held in idx, in place.
# - Shuffles the sentences inside every bucket and stores the shuffled
#   piddles back into $self->data.
# - Rebuilds the nddata/ndlabel NDArray lists from the shuffled buckets.
method reset()
{
    $self->curr_idx(0);

    # Replace the contents of the existing idx array. The array reference
    # held by the attribute stays the same.
    my $schedule = $self->idx;
    @{ $schedule } = shuffle(@{ $schedule });

    # Fresh output lists, filled bucket by bucket below.
    my $data_arrays  = [];
    my $label_arrays = [];
    $self->nddata($data_arrays);
    $self->ndlabel($label_arrays);

    my $buckets = $self->data;
    for my $bucket_no (0 .. $#{ $buckets })
    {
        # Reorder this bucket's sentences and write the result back, so
        # the shuffle is also visible through $self->data.
        my $sentences = $buckets->[$bucket_no] = pdl_shuffle($buckets->[$bucket_no]);

        # Labels: each padded sentence shifted left by one position along
        # PDL dim 0, with invalid_label filling the final position.
        my $next_tokens = $sentences->zeros;
        $next_tokens->slice([0, -2], 'X')  .= $sentences->slice([1, -1], 'X');
        $next_tokens->slice([-1, -1], 'X') .= $self->invalid_label;

        push @{ $data_arrays },  AI::MXNet::NDArray->array($sentences,   dtype => $self->dtype);
        push @{ $label_arrays }, AI::MXNet::NDArray->array($next_tokens, dtype => $self->dtype);
    }
}

# Serve the next batch from the (bucket, offset) schedule in idx.
#
# Returns undef once every scheduled entry has been consumed since the
# last reset(). Otherwise it returns an AI::MXNet::DataBatch with these
# fields:
# - data/label: batch_size consecutive rows of the bucket's nddata/ndlabel
#   arrays, transposed when major_axis is 1 (TN, time-major layout).
# - bucket_key: the bucket's entry in buckets.
# - pad: 0.
# - provide_data/provide_label: descriptors built from the actual shapes
#   of the returned arrays.
method next()
{
    my $cursor = $self->curr_idx;
    return undef if($cursor == @{ $self->idx });

    my ($bucket_no, $offset) = @{ $self->idx->[$cursor] };
    $self->curr_idx($cursor + 1);

    my $last_row = $offset + $self->batch_size - 1;

    # Take rows [$offset, $last_row] of the chosen bucket from a list of
    # per-bucket NDArrays, transposing for the time-major layout.
    # A fresh range arrayref is passed on every call.
    my $take = sub {
        my ($arrays) = @_;
        my $rows = $arrays->[$bucket_no]->slice([$offset, $last_row]);
        return $self->major_axis == 1 ? $rows->T : $rows;
    };
    my $data  = $take->($self->nddata);
    my $label = $take->($self->ndlabel);

    # Describe one returned array by its own (bucket-specific) shape.
    my $describe = sub {
        my ($name, $array) = @_;
        return AI::MXNet::DataDesc->new(
            name   => $name,
            shape  => $array->shape,
            dtype  => $self->dtype,
            layout => $self->layout
        );
    };

    return AI::MXNet::DataBatch->new(
        data          => [$data],
        label         => [$label],
        bucket_key    => $self->buckets->[$bucket_no],
        pad           => 0,
        provide_data  => [$describe->($self->data_name,  $data)],
        provide_label => [$describe->($self->label_name, $label)],
    );
}

1;



( run in 1.359 second using v1.01-cache-2.11-cpan-39bf76dae61 )