AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Metric.pm view on Meta::CPAN
=head1 NAME
AI::MXNet::Metric - Online evaluation metric module.
=cut
# Return the size used when comparing labels against predictions:
#   - plain (unblessed) array ref  -> number of elements
#   - blessed PDL                  -> last entry of its shape
#   - any other blessed object     -> first entry of its shape (e.g. an NDArray)
sub _calculate_shape
{
    my ($input) = @_;
    # Scalar assignment keeps every branch in scalar context.
    my $size = !blessed($input)   ? scalar @{ $input }
             : $input->isa('PDL') ? $input->shape->at(-1)
             :                      $input->shape->[0];
    return $size;
}
# Confess unless $labels and $preds have the same size as reported by
# _calculate_shape.  On success returns the (true) comparison result.
func check_label_shapes(
    ArrayRef|AI::MXNet::NDArray|PDL $labels,
    ArrayRef|AI::MXNet::NDArray|PDL $preds
)
{
    my $label_shape = _calculate_shape($labels);
    my $pred_shape  = _calculate_shape($preds);
    my $same_size   = $pred_shape == $label_shape;
    $same_size or Carp::confess(
        "Shape of labels $label_shape does not match shape of predictions $pred_shape"
    );
}
=head1 DESCRIPTION
Base class of all evaluation metrics.
=cut
package AI::MXNet::EvalMetric;
use Mouse;

# Stringify as "EvalMetric: " followed by a terse Data::Dumper rendering of
# the name => value hash from get_name_value().
use overload
    '""' => sub {
        my ($metric) = @_;
        my $dumper = Data::Dumper->new([ $metric->get_name_value() ]);
        return "EvalMetric: " . $dumper->Purity(1)->Deepcopy(1)->Terse(1)->Dump;
    },
    fallback => 1;

has 'name'       => (is => 'rw', isa => 'Str');
# When defined, the metric keeps one accumulator slot per index 0..num-1.
has 'num'        => (is => 'rw', isa => 'Int');
# Running count(s) of instances: a scalar, or an array ref when num is set.
has 'num_inst'   => (is => 'rw', isa => 'Maybe[Int|ArrayRef[Int]]');
# Running sum(s) of the metric: a scalar, or an array ref when num is set.
has 'sum_metric' => (is => 'rw', isa => 'Maybe[Num|ArrayRef[Num]]');

# Every freshly constructed metric starts with zeroed accumulators.
sub BUILD
{
    my ($self) = @_;
    $self->reset;
}

# Accumulate statistics for one batch.  Subclasses must override this.
method update($label, $pred)
{
    confess('NotImplemented');
}

# Zero the accumulators: plain 0 when num is undef, otherwise fresh
# array refs holding num zeros each.
method reset()
{
    my $slots = $self->num;
    if (defined $slots)
    {
        $self->num_inst([ (0) x $slots ]);
        $self->sum_metric([ (0) x $slots ]);
    }
    else
    {
        $self->num_inst(0);
        $self->sum_metric(0);
    }
}

# Return the current metric value(s).
#   num undef: ($name, sum_metric / num_inst), or ($name, 'nan') if num_inst is 0.
#   num set:   (\@names, \@values) with names "<name>_<i>" and per-slot
#              averages, using 'nan' for any slot whose count is 0.
method get()
{
    my $num = $self->num;
    if (defined $num)
    {
        my @names  = map { sprintf('%s_%d', $self->name, $_) } 0 .. $num - 1;
        my $sums   = $self->sum_metric;
        my $counts = $self->num_inst;
        my @values = map {
            my $i = $_;
            $counts->[$i] != 0 ? $sums->[$i] / $counts->[$i] : 'nan';
        } 0 .. $#{ $sums };
        return (\@names, \@values);
    }
    return ($self->name, 'nan') if $self->num_inst == 0;
    return ($self->name, $self->sum_metric / $self->num_inst);
}

# Return get()'s result as a hash ref of name => value, treating a
# scalar name/value as a one-element list.
method get_name_value()
{
    my ($names, $values) = $self->get;
    my @keys = ref $names  ? @{ $names }  : ($names);
    my @vals = ref $values ? @{ $values } : ($values);
    my %name_value;
    @name_value{@keys} = @vals;
    return \%name_value;
}
package AI::MXNet::CompositeEvalMetric;
use Mouse;
extends 'AI::MXNet::EvalMetric';
# Child metrics; update() and reset() are forwarded to each of them and
# get() concatenates their names and values.
has 'metrics' => (is => 'rw', isa => 'ArrayRef[AI::MXNet::EvalMetric]', default => sub { [] });
has '+name' => (default => 'composite');

# Add a child metric (appended after the existing children).
method add(AI::MXNet::EvalMetric $metric)
{
    push @{ $self->metrics }, $metric;
}

# Get the child metric at 0-based position $index.
# Confesses when $index is outside 0..$max.  The parameter uses Mouse's
# built-in 'Int' type (lowercase 'int' is not a built-in type name).
# Negative indices are rejected: otherwise Perl would count from the end
# or silently return undef, contradicting the "0 and $max" range.
method get_metric(Int $index)
{
    my $max = @{ $self->metrics } - 1;
    confess("Metric index $index is out of range 0 and $max")
        if $index < 0 || $index > $max;
    return $self->metrics->[$index];
}

# Forward a batch to every child metric; each child receives the full
# $labels and $preds array refs.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    for my $metric (@{ $self->metrics })
    {
        $metric->update($labels, $preds);
    }
}

# Reset every child metric.  Only the children are reset; the composite's
# inherited num_inst/sum_metric are not touched here.
method reset()
{
    for my $metric (@{ $self->metrics })
    {
        $metric->reset;
    }
}

# Return (\@names, \@values): the concatenation of every child's get()
# result, with scalar names/values wrapped as one-element lists.
method get()
{
    my $names = [];
    my $results = [];
    for my $metric (@{ $self->metrics })
    {
        my ($name, $result) = $metric->get;
        $name = [$name] unless ref $name;
        $result = [$result] unless ref $result;
        push @$names, @$name;
        push @$results, @$result;
    }
    return ($names, $results);
}
########################
# CLASSIFICATION METRICS
########################
package AI::MXNet::Accuracy;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'accuracy');

# For each (label, prediction) pair: if the prediction's shape differs from
# the label's, reduce it with argmax_channel (presumably per-class scores ->
# class index; TODO confirm).  Then add the number of element-wise matches to
# sum_metric and the prediction's element count to num_inst.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        my $pred_dims  = join(',', @{ $pred->shape });
        my $label_dims = join(',', @{ $label->shape });
        $pred = AI::MXNet::NDArray->argmax_channel($pred)
            if $pred_dims ne $label_dims;
        AI::MXNet::Metric::check_label_shapes($label, $pred);
        my $hits = ($pred->aspdl->flat == $label->aspdl->flat)->sum;
        $self->sum_metric($self->sum_metric + $hits);
        $self->num_inst($self->num_inst + $pred->size);
    }, $labels, $preds);
}
package AI::MXNet::TopKAccuracy;
use Mouse;
use List::Util qw/min/;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'top_k_accuracy');
# The k of top-k accuracy; must be > 1 (enforced in BUILD).
# Typed with Mouse's built-in 'Int': the former lowercase 'int' is not a
# built-in type name, so it would be treated as a class type that plain
# integers (including the default 1) cannot satisfy.
has 'top_k' => (is => 'rw', isa => 'Int', default => 1);
# Reject top_k <= 1 (plain Accuracy covers that case) and append
# "_<top_k>" to the metric name.
sub BUILD
{
    my ($self) = @_;
    my $k = $self->top_k;
    unless ($k > 1)
    {
        confess("Please use Accuracy if top_k is no more than 1");
    }
    $self->name(join('_', $self->name, $k));
}
( run in 0.980 second using v1.01-cache-2.11-cpan-39bf76dae61 )