AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Metric.pm  view on Meta::CPAN

package AI::MXNet::Metric;
use strict;
use warnings;
use AI::MXNet::Function::Parameters;
use Scalar::Util qw/blessed/;

=head1 NAME

    AI::MXNet::Metric - Online evaluation metric module.
=cut

# Return the size that check_label_shapes() compares:
#   - PDL object:            the last entry of its shape (->shape->at(-1)),
#   - any other blessed ref: the first entry of its shape (->shape->[0]),
#   - unblessed array ref:   the number of elements.
sub _calculate_shape
{
    my ($data) = @_;

    # Anything unblessed is treated as a plain array reference.
    return scalar @{ $data } unless blessed($data);

    my $dims = $data->shape;
    my $size = $data->isa('PDL') ? $dims->at(-1) : $dims->[0];
    return $size;
}
# Verify that labels and predictions have the same size, as measured by
# _calculate_shape(). Dies with a stack trace (Carp::confess) naming both
# sizes when they differ.
func check_label_shapes(
    ArrayRef|AI::MXNet::NDArray|PDL $labels,
    ArrayRef|AI::MXNet::NDArray|PDL $preds
)
{
    my $label_shape = _calculate_shape($labels);
    my $pred_shape  = _calculate_shape($preds);
    $pred_shape == $label_shape
        or Carp::confess(
            "Shape of labels $label_shape does not match shape of predictions $pred_shape"
        );
}

=head1 DESCRIPTION

    Base class of all evaluation metrics.
=cut

package AI::MXNet::EvalMetric;
use Mouse;
# Stringification: the literal prefix "EvalMetric: " followed by a
# Data::Dumper dump (Purity, Deepcopy and Terse enabled) of the list
# returned by get_name_value(). fallback => 1 lets Perl autogenerate the
# other operators from this one.
use overload '""' => sub {
    return "EvalMetric: "
            .Data::Dumper->new(
                [shift->get_name_value()]
            )->Purity(1)->Deepcopy(1)->Terse(1)->Dump
},  fallback => 1;
# name: label that get() reports alongside the metric value(s).
has 'name'       => (is => 'rw', isa => 'Str');
# num: when defined, reset() keeps num_inst and sum_metric as array refs
# of this length; when undefined they are plain scalars.
has 'num'        => (is => 'rw', isa => 'Int');
# num_inst: denominator(s) used by get() (sum_metric / num_inst).
has 'num_inst'   => (is => 'rw', isa => 'Maybe[Int|ArrayRef[Int]]');
# sum_metric: numerator(s) used by get() (sum_metric / num_inst).
has 'sum_metric' => (is => 'rw', isa => 'Maybe[Num|ArrayRef[Num]]');

# Mouse construction hook: puts the accumulators into their starting
# state by delegating to reset().
sub BUILD
{
    my ($self) = @_;
    return $self->reset;
}

# Hook for updating the metric state from $label and $pred.
# This base-class version does no work: it always dies with a stack trace
# and the message 'NotImplemented'.
# NOTE(review): presumably every concrete metric overrides this -- confirm.
method update($label, $pred)
{
    confess('NotImplemented');
}

# Zero the accumulators. When 'num' is set, num_inst and sum_metric each
# become a fresh array ref holding 'num' zeros; otherwise both become the
# scalar 0.
method reset()
{
    my $slots = $self->num;
    if(defined $slots)
    {
        # Build two separate lists so the attributes never share one array.
        $self->num_inst([(0) x $slots]);
        $self->sum_metric([(0) x $slots]);
    }
    else
    {
        $self->num_inst(0);
        $self->sum_metric(0);
    }
}

method get()
{
    if(not defined $self->num)
    {
        if($self->num_inst == 0)
        {
            return ($self->name, 'nan');
        }
        else
        {
            return ($self->name, $self->sum_metric / $self->num_inst);
        }
    }
    else
    {
        my $names = [map { sprintf('%s_%d', $self->name, $_) } 0..$self->num-1];
        my $values = [];
        for (my $i = 0; $i < @{ $self->sum_metric }; $i++)
        {
            my ($x, $y) = ($self->sum_metric->[$i], $self->num_inst->[$i]);
            if($y != 0)
            {
                push (@$values, $x/$y);
            }
            else
            {
                push (@$values, 'nan');
            }

 view all matches for this distribution
 view release on metacpan -  search on metacpan

( run in 2.119 seconds using v1.00-cache-2.02-grep-82fe00e-cpan-9f2165ba459b )