AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/IO.pm  view on Meta::CPAN

    {
        my $data  = shift;
        return $class->$orig(data => $data, @_);
    }
    return $class->$orig(@_);
};

# Post-construction hook: prepares the data/label sources and seeds the
# iterator state (data, data_list, label, num_source, cursor, num_data).
#
# Steps:
#   1. Pass data and label through AI::MXNet::IO->init_data; the code below
#      treats each result as an array of [name, array] pairs.
#   2. Die (confess) when dimension 0 of the first data source is smaller
#      than batch_size.
#   3. When _shuffle is set, reorder every source with one shared permutation.
#   4. When last_batch_handle is 'discard', truncate every source to a whole
#      number of batches and record the truncated length as num_data.
sub BUILD
{
    my $self  = shift;
    my $data  = AI::MXNet::IO->init_data($self->data,  allow_empty => 0, default_name => 'data');
    my $label = AI::MXNet::IO->init_data($self->label, allow_empty => 1, default_name => $self->label_name);
    # Dimension 0 of the first data source is used as the sample count.
    my $num_data  = $data->[0][1]->shape->[0];
    confess("size of data dimension 0 $num_data < batch_size ${\ $self->batch_size }")
        unless($num_data >= $self->batch_size);
    if($self->_shuffle)
    {
        # A single permutation is applied to data and label alike so that
        # rows stay paired with each other after shuffling.
        my @idx = shuffle(0..$num_data-1);
        for my $source (@$data, @$label)
        {
            $source->[1] = AI::MXNet::NDArray->array(pdl_shuffle($source->[1]->aspdl, \@idx));
        }
    }
    if($self->last_batch_handle eq 'discard')
    {
        # $new_n is the index of the last row to keep; the slice bounds are
        # used as an inclusive [first, last] pair, hence the trailing -1.
        my $new_n = $num_data - $num_data % $self->batch_size - 1;
        for my $source (@$data, @$label)
        {
            $source->[1] = $source->[1]->slice([0, $new_n]);
        }
        # Keep the recorded sample count in sync with the truncated arrays;
        # otherwise iter_next would still report a trailing partial batch.
        $num_data = $new_n + 1;
    }
    my $data_list  = [map { $_->[1] } (@{ $data }, @{ $label })];
    my $num_source = @{ $data_list };
    # Start one batch before the first row so the first iter_next lands on 0.
    my $cursor = -$self->batch_size;
    $self->data($data);
    $self->data_list($data_list);
    $self->label($label);
    $self->num_source($num_source);
    $self->cursor($cursor);
    $self->num_data($num_data);
}

# Describe the data sources served by this iterator.
# Returns an array ref holding one AI::MXNet::DataDesc per entry of
# $self->data, where dimension 0 of each shape is set to the batch size.
method provide_data()
{
    my @descs;
    for my $entry (@{ $self->data })
    {
        my ($name, $array) = @{ $entry };
        my $shape = $array->shape;
        # A batch carries batch_size rows regardless of the source length.
        $shape->[0] = $self->batch_size;
        push @descs, AI::MXNet::DataDesc->new(name => $name, shape => $shape, dtype => $array->dtype);
    }
    return \@descs;
}

# Describe the label sources served by this iterator.
# Returns an array ref holding one AI::MXNet::DataDesc per entry of
# $self->label, where dimension 0 of each shape is set to the batch size.
method provide_label()
{
    my @descs;
    for my $entry (@{ $self->label })
    {
        my ($name, $array) = @{ $entry };
        my $shape = $array->shape;
        # A batch carries batch_size rows regardless of the source length.
        $shape->[0] = $self->batch_size;
        push @descs, AI::MXNet::DataDesc->new(name => $name, shape => $shape, dtype => $array->dtype);
    }
    return \@descs;
}

# Rewind to the very beginning, ignoring any rolled-over position:
# the cursor is placed one batch before row 0.
method hard_reset()
{
    my $start = -$self->batch_size;
    $self->cursor($start);
}

# Rewind the iterator for another pass.
# In 'roll_over' mode, when the cursor has moved past num_data, the cursor
# is offset by (cursor % num_data) % batch_size instead of starting exactly
# one batch before row 0; in every other case this matches hard_reset.
method reset()
{
    my $rolled_over = $self->last_batch_handle eq 'roll_over'
                      && $self->cursor > $self->num_data;
    my $carry = $rolled_over
        ? ($self->cursor % $self->num_data) % $self->batch_size
        : 0;
    $self->cursor(-$self->batch_size + $carry);
}

# Advance the cursor by one batch.
# Returns true while the new cursor position is still below num_data.
method iter_next()
{
    my $advanced = $self->cursor + $self->batch_size;
    $self->cursor($advanced);
    return $self->cursor < $self->num_data;
}

# Fetch the next batch.
# Returns an AI::MXNet::DataBatch built from getdata/getlabel/getpad (with
# index left undefined), or undef once iter_next reports no more batches.
method next()
{
    # Explicit undef (not a bare return) so list-context callers still
    # receive a single undef element.
    return undef unless $self->iter_next;

    return AI::MXNet::DataBatch->new(
        data  => $self->getdata,
        label => $self->getlabel,
        pad   => $self->getpad,
        index => undef
    );
}

# Load data from underlying arrays, internal use only
method _getdata($data_source)
{
    confess("DataIter needs reset.") unless $self->cursor < $self->num_data;
    if(($self->cursor + $self->batch_size) <= $self->num_data)
    {
        return [
            map {
                $_->[1]->slice([$self->cursor,$self->cursor+$self->batch_size-1])
            } @{ $data_source }
        ];
    }
    else
    {
        my $pad = $self->batch_size - $self->num_data + $self->cursor - 1;
        return [
            map {
                AI::MXNet::NDArray->concatenate(
                    [
                        $_->[1]->slice([$self->cursor, -1]),

lib/AI/MXNet/IO.pm  view on Meta::CPAN

# Build a constructor closure for the data iterator identified by $handle.
#
# The iterator's name, description and argument metadata are fetched through
# DataIterGetIterInfo; a doc string is assembled from them and stored, along
# with the name, in %iter_meta keyed by the returned code ref.
#
# The returned closure is called as a class method with an optional trailing
# hash ref of keyword arguments. It confesses on any other positional
# argument, renders array-ref values as "(a,b,...)" strings, creates the
# iterator via DataIterCreateIter and wraps the new handle in an
# AI::MXNet::MXDataIter.
func _make_io_iterator($handle)
{
    my ($iter_name, $desc,
        $arg_names, $arg_types, $arg_descs
    ) = @{ check_call(AI::MXNetCAPI::DataIterGetIterInfo($handle)) };
    my $param_str = build_param_doc($arg_names, $arg_types, $arg_descs);
    my $doc_str = join(
        "\n",
        $desc,
        '',
        $param_str,
        'name : string, required.',
        '    Name of the resulting data iterator.',
        '',
        'Returns',
        '-------',
        'iterator: DataIter',
        '    The result iterator.'
    );
    my $iter = sub {
        my ($class, @positional) = @_;
        # A trailing hash ref carries the keyword arguments.
        my %params = (@positional and ref($positional[-1]) eq 'HASH')
            ? %{ pop @positional }
            : ();
        Carp::confess("$iter_name can only accept keyword arguments")
            if @positional;
        # values() aliases the hash slots, so assigning rewrites them in place.
        for my $value (values %params)
        {
            $value = "(" . join(",", @{ $value }) . ")"
                if ref($value) eq 'ARRAY';
        }
        my $iter_handle = check_call(
            AI::MXNetCAPI::DataIterCreateIter(
                $handle,
                scalar(keys %params),
                \%params
            )
        );
        return AI::MXNet::MXDataIter->new(handle => $iter_handle, %params);
    };
    $iter_meta{$iter}{__name__} = $iter_name;
    $iter_meta{$iter}{__doc__}  = $doc_str;
    return $iter;
}

# Install one constructor sub in this package for every data iterator
# reported by ListDataIters, named after the iterator's registered name.
method _init_io_module()
{
    for my $creator_handle (@{ check_call(AI::MXNetCAPI::ListDataIters()) })
    {
        my $constructor = _make_io_iterator($creator_handle);
        my $sub_name    = $iter_meta{ $constructor }{__name__};
        # Symbolic glob assignment is what defines the sub under a
        # run-time computed name.
        no strict 'refs';
        *{__PACKAGE__."::$sub_name"} = $constructor;
    }
}

# Register the data iterator constructors in this package at load time
__PACKAGE__->_init_io_module;

1;



( run in 0.760 second using v1.01-cache-2.11-cpan-39bf76dae61 )