AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/IO.pm  view on Meta::CPAN

            if(not exists $vocab->{ $word })
            {
                assert($new_vocab, "Unknown token: $word");
                if($idx == $invalid_label)
                {
                    $idx += 1;
                }
                $vocab->{$word} = $idx;
                $idx += 1;
            }
            push @coded, $vocab->{ $word };
        }
        push @res, \@coded;
    }
    return (\@res, $vocab);
}

package AI::MXNet::BucketSentenceIter;

=encoding UTF-8

=head1 NAME

    AI::MXNet::BucketSentenceIter
=cut

=head1 DESCRIPTION

    Simple bucketing iterator for language models.
    The label for each step is constructed from the data of
    the next step.
=cut

=head2 new

    Parameters
    ----------
    sentences : array ref of array refs of int
        encoded sentences
    batch_size : int
        number of sentences in each batch
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>.
        Also used to pad sentences that are shorter than their bucket.
    dtype : str, default 'float32'
        data type
    buckets : array ref of int
        sizes of the data buckets. Generated automatically from the
        sentence lengths if undef.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str, default 'NT'
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
=cut

use Mouse;
use AI::MXNet::Base;
use List::Util qw(shuffle max);
extends 'AI::MXNet::DataIter';
# Constructor arguments. Their user-facing meaning is described in the
# "new" POD section of this package.
has 'sentences'     => (is => 'ro', isa => 'ArrayRef[ArrayRef]', required => 1);
# The leading '+' modifies the batch_size attribute inherited through
# "extends 'AI::MXNet::DataIter'" instead of declaring a new attribute.
has '+batch_size'   => (is => 'ro', isa => 'Int',                required => 1);
# BUILD fills every per-sentence buffer with this value before copying the
# sentence in, so it doubles as the padding value.
has 'invalid_label' => (is => 'ro', isa => 'Int',   default => -1);
has 'data_name'     => (is => 'ro', isa => 'Str',   default => 'data');
has 'label_name'    => (is => 'ro', isa => 'Str',   default => 'softmax_label');
# NOTE(review): the 'Dtype' type constraint is not declared in this part of
# the file; presumably it comes from the AI::MXNet type library -- confirm.
has 'dtype'         => (is => 'ro', isa => 'Dtype', default => 'float32');
# BUILD derives major_axis from the position of 'N' in this string.
has 'layout'        => (is => 'ro', isa => 'Str',   default => 'NT');
# Writable because BUILD stores the generated bucket list here when the
# caller did not supply one.
has 'buckets'       => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
# Internal state: "init_arg => undef" means none of these can be passed to
# the constructor. BUILD assigns data, nddata, ndlabel, major_axis and
# default_bucket_key; the remaining ones are presumably set further down in
# the file (not visible here -- verify).
has [qw/data nddata ndlabel
        major_axis default_bucket_key
        provide_data provide_label
        idx curr_idx
    /]              => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    if(not defined $self->buckets)
    {
        my @buckets;
        my $p = pdl([map { scalar(@$_) } @{ $self->sentences }]);
        enumerate(sub {
            my ($i, $j) = @_;
            if($j >= $self->batch_size)
            {
                push @buckets, $i;
            }
        }, $p->histogram(1,0,$p->max+1)->unpdl);
        $self->buckets(\@buckets);
    }
    @{ $self->buckets } = sort { $a <=> $b } @{ $self->buckets };
    my $ndiscard = 0;
    $self->data([map { [] } 0..@{ $self->buckets }-1]);
    for my $i (0..@{$self->sentences}-1)
    {
        my $buck = bisect_left($self->buckets, scalar(@{ $self->sentences->[$i] }));
        if($buck == @{ $self->buckets })
        {
            $ndiscard += 1;
            next;
        }
        my $buff = AI::MXNet::NDArray->full(
            [$self->buckets->[$buck]],
            $self->invalid_label,
            dtype => $self->dtype
        )->aspdl;
        $buff->slice([0, @{ $self->sentences->[$i] }-1]) .= pdl($self->sentences->[$i]);
        push @{ $self->data->[$buck] }, $buff;
    }
    $self->data([map { pdl(PDL::Type->new(DTYPE_MX_TO_PDL->{$self->dtype}), $_) } @{$self->data}]);
    AI::MXNet::Logging->warning("discarded $ndiscard sentences longer than the largest bucket.")
        if $ndiscard;
    $self->nddata([]);
    $self->ndlabel([]);
    $self->major_axis(index($self->layout, 'N'));
    $self->default_bucket_key(max(@{ $self->buckets }));
    my $shape;
    if($self->major_axis == 0)
    {
        $shape = [$self->batch_size, $self->default_bucket_key];
    }

 view all matches for this distribution
 view release on metacpan -  search on metacpan

( run in 0.452 second using v1.00-cache-2.02-grep-82fe00e-cpan-2c419f77a38b )