AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/RNN/IO.pm  view on Meta::CPAN

package AI::MXNet::RNN::IO;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;

=encoding UTF-8

=head1 NAME

    AI::MXNet::RNN::IO - Functions for constructing recurrent neural networks.
=cut

=head1 DESCRIPTION

    Functions for constructing recurrent neural networks.
=cut

=head2 encode_sentences

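    Encode sentences and (optionally) build a mapping
    from string tokens to integer indices. When no vocabulary
    is supplied, a new one is created and unknown tokens are
    added to it. When a vocabulary is supplied, it is not
    extended: an unknown token triggers an "Unknown token"
    assertion failure.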

    Parameters
    ----------
    $sentences : array ref of array refs of str
        An array ref of sentences to encode. Each sentence
        should be an array ref of string tokens.
    :$vocab : undef or hash ref of str -> int
        Optional input Vocabulary
    :$invalid_label : int, default -1
        Index for invalid token, like <end-of-sentence>
    :$invalid_key : str, default '\n'
        Key for invalid token. Uses '\n' for end
        of sentence by default.
    :$start_label=0 : int
        lowest index.

    Returns
    -------
    $result : array ref of array refs of int
        encoded sentences
    $vocab : hash ref of str -> int
        result vocabulary
=cut


method encode_sentences(
    ArrayRef[ArrayRef]  $sentences,
    Maybe[HashRef]     :$vocab=,
    Int                :$invalid_label=-1,
    Str                :$invalid_key="\n",
    Int                :$start_label=0
)
{
    # Map every token of every sentence to its integer index.
    #
    # If the caller did not pass a vocabulary, a fresh one is created
    # (seeded with $invalid_key => $invalid_label) and it is allowed to grow.
    # A caller-supplied vocabulary is treated as fixed: meeting a token that
    # is not in it fails the assertion below.
    #
    # Returns a two element list: (array ref of encoded sentences, vocabulary).
    my $may_grow = defined($vocab) ? 0 : 1;
    if($may_grow)
    {
        $vocab = { $invalid_key => $invalid_label };
    }

    # next index to hand out to a previously unseen token
    my $next_index = $start_label;

    my @encoded;
    for my $sentence (@{ $sentences })
    {
        my @indices;
        for my $token (@{ $sentence })
        {
            unless(exists $vocab->{ $token })
            {
                assert($may_grow, "Unknown token: $token");
                # the index reserved for the invalid token is never given
                # to a regular token, so step over it
                $next_index += 1 if $next_index == $invalid_label;
                $vocab->{ $token } = $next_index;
                $next_index += 1;
            }
            push @indices, $vocab->{ $token };
        }
        push @encoded, \@indices;
    }
    return (\@encoded, $vocab);
}

package AI::MXNet::BucketSentenceIter;

=encoding UTF-8

=head1 NAME

    AI::MXNet::BucketSentenceIter
=cut

=head1 DESCRIPTION

    Simple bucketing iterator for language model.
    Label for each step is constructed from data of
    next step.
=cut

=head2 new

    Parameters
    ----------
    sentences : array ref of array refs of int
        encoded sentences
    batch_size : int
        batch_size of data
    invalid_label : int, default -1
        key for invalid label, e.g. <end-of-sentence>
    dtype : str, default 'float32'
        data type
    buckets : array ref of int
        size of data buckets. Automatically generated if undef.
    data_name : str, default 'data'
        name of data
    label_name : str, default 'softmax_label'
        name of label
    layout : str
        format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).
=cut

use Mouse;
use AI::MXNet::Base;
use List::Util qw(shuffle max);
extends 'AI::MXNet::DataIter';
# Constructor attributes (see the POD for "new" above).

# Encoded sentences: array ref of array refs of int. Mandatory.
has 'sentences'     => (is => 'ro', isa => 'ArrayRef[ArrayRef]', required => 1);
# Batch size of the data. The leading '+' adjusts an attribute that already
# exists in the parent class (AI::MXNet::DataIter); here it is made mandatory.
has '+batch_size'   => (is => 'ro', isa => 'Int',                required => 1);
# Key for the invalid label, e.g. <end-of-sentence>.
has 'invalid_label' => (is => 'ro', isa => 'Int',   default => -1);
# Name of the data.
has 'data_name'     => (is => 'ro', isa => 'Str',   default => 'data');
# Name of the label.
has 'label_name'    => (is => 'ro', isa => 'Str',   default => 'softmax_label');
# Data type.
has 'dtype'         => (is => 'ro', isa => 'Dtype', default => 'float32');
# Format of data and label: 'NT' means (batch_size, length),
# 'TN' means (length, batch_size).
has 'layout'        => (is => 'ro', isa => 'Str',   default => 'NT');
# Sizes of the data buckets; generated automatically when left undef.
has 'buckets'       => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
# Read-write attributes that cannot be passed to the constructor
# (init_arg => undef).
# NOTE(review): presumably internal iterator state filled in by BUILD and the
# iteration methods - confirm against the rest of the class.
has [qw/data nddata ndlabel
        major_axis default_bucket_key
        provide_data provide_label
        idx curr_idx
    /]              => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    if(not defined $self->buckets)
    {
        my @buckets;



( run in 1.011 second using v1.01-cache-2.11-cpan-39bf76dae61 )