AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Metric.pm view on Meta::CPAN
my ($name, $result) = $metric->get;
$name = [$name] unless ref $name;
$result = [$result] unless ref $result;
push @$names, @$name;
push @$results, @$result;
}
return ($names, $results);
}
########################
# CLASSIFICATION METRICS
########################
package AI::MXNet::Accuracy;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';

# Plain classification accuracy: fraction of predicted class ids that equal
# the label.
has '+name' => (default => 'accuracy');

# Accumulate exact-match counts for one batch.
# $labels / $preds are parallel array refs of NDArrays. When a prediction's
# shape differs from its label's, it is taken to hold per-class scores and is
# reduced to class ids with argmax_channel before comparing.
# sum_metric grows by the number of matches, num_inst by the number of
# predicted elements.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($truth, $prediction) = @_;
        my $truth_dims = join(',', @{ $truth->shape });
        my $pred_dims  = join(',', @{ $prediction->shape });
        $prediction = AI::MXNet::NDArray->argmax_channel($prediction)
            unless $pred_dims eq $truth_dims;
        AI::MXNet::Metric::check_label_shapes($truth, $prediction);
        my $flat_pred  = $prediction->aspdl->flat;
        my $flat_truth = $truth->aspdl->flat;
        my $matches    = ($flat_pred == $flat_truth)->sum;
        $self->sum_metric($self->sum_metric + $matches);
        $self->num_inst($self->num_inst + $prediction->size);
    }, $labels, $preds);
}
package AI::MXNet::TopKAccuracy;
use Mouse;
use List::Util qw/min/;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';

# Top-k classification accuracy: a sample counts as a hit when its label is
# among the k highest-scoring predicted classes.
has '+name' => (default => 'top_k_accuracy');

# Number of top-scoring classes to consider; BUILD requires it to be > 1
# (use AI::MXNet::Accuracy for k == 1).
# 'Int' is Mouse's built-in integer type. The previous lowercase 'int' is not
# a built-in type name, so Mouse treated it as a class-name constraint and
# rejected every plain number passed in.
has 'top_k' => (is => 'rw', isa => 'Int', default => 1);

sub BUILD
{
    my $self = shift;
    confess("Please use Accuracy if top_k is no more than 1")
        unless $self->top_k > 1;
    # Suffix the metric name with k, e.g. "top_k_accuracy_5".
    $self->name($self->name . "_" . $self->top_k);
}

# Accumulate top-k hits for one batch.
# $labels / $preds are parallel array refs of NDArrays; each prediction may
# have at most 2 dims. Predictions are converted to PDL and argsorted with
# qsorti (ascending along PDL dim 0), so for 2-D predictions the last top_k
# indices along dim 0 are the highest-scoring classes of each sample.
# sum_metric grows by the number of hits, num_inst by the number of samples
# (the last PDL dim).
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred_label) = @_;
        confess('Predictions should be no more than 2 dims')
            unless @{ $pred_label->shape } <= 2;
        $pred_label = $pred_label->aspdl->qsorti;
        $label = $label->astype('int32')->aspdl;
        AI::MXNet::Metric::check_label_shapes($label, $pred_label);
        my $num_samples = $pred_label->shape->at(-1);
        my $num_dims = $pred_label->ndims;
        if($num_dims == 1)
        {
            my $sum = ($pred_label->flat == $label->flat)->sum;
            $self->sum_metric($self->sum_metric + $sum);
        }
        elsif($num_dims == 2)
        {
            # PDL dim 0 is the class axis here.
            my $num_classes = $pred_label->shape->at(0);
            my $top_k = min($num_classes, $self->top_k);
            for my $j (0..$top_k-1)
            {
                # Column of the (j+1)-th highest score for every sample.
                my $sum = ($pred_label->slice($num_classes -1 - $j, 'X')->flat == $label->flat)->sum;
                $self->sum_metric($self->sum_metric + $sum);
            }
        }
        $self->num_inst($self->num_inst + $num_samples);
    }, $labels, $preds);
}
# Calculate the F1 score of a binary classification problem.
package AI::MXNet::F1;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'f1');

# For every (label, pred) pair, compute the F1 score of that batch and add it
# to sum_metric. num_inst grows by one per pair, so one F1 value is recorded
# per batch, not per sample.
# Predicted class = index of the highest score along PDL dim 0 (maximum_ind).
# Class 1 is the positive class. Labels with more than two distinct values
# are rejected. Precision, recall and F1 each fall back to 0 when their
# denominator is 0.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($truth_nd, $score_nd) = @_;
        AI::MXNet::Metric::check_label_shapes($truth_nd, $score_nd);
        my $guesses = $score_nd->aspdl->maximum_ind;
        my $truths  = $truth_nd->astype('int32')->aspdl;
        confess("F1 currently only supports binary classification.")
            if $truths->uniq->shape->at(0) > 2;

        my ($hits, $false_alarms, $misses) = (0, 0, 0);
        zip(sub {
            my ($guess, $actual) = @_;
            if ($guess == 1)
            {
                if    ($actual == 1) { $hits++ }
                elsif ($actual == 0) { $false_alarms++ }
            }
            elsif ($guess == 0 and $actual == 1)
            {
                $misses++;
            }
        }, $guesses->unpdl, $truths->unpdl);

        my $flagged   = $hits + $false_alarms;
        my $relevant  = $hits + $misses;
        my $precision = $flagged  > 0 ? $hits / $flagged  : 0;
        my $recall    = $relevant > 0 ? $hits / $relevant : 0;
        my $pr_sum    = $precision + $recall;
        my $f1_score  = $pr_sum > 0 ? 2 * $precision * $recall / $pr_sum : 0;

        $self->sum_metric($self->sum_metric + $f1_score);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
package AI::MXNet::Perplexity;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'Perplexity');
has 'ignore_label' => (is => 'ro', isa => 'Maybe[Int]');
has 'axis' => (is => 'ro', isa => 'Int', default => -1);

# Support the positional shorthand AI::MXNet::Perplexity->new($ignore_label).
# A single HASH reference is passed through untouched so the standard Mouse
# form ->new({ ignore_label => ..., axis => ... }) keeps working instead of
# being mistaken for an ignore_label value.
around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    return $class->$orig(ignore_label => $_[0])
        if @_ == 1 and ref($_[0]) ne 'HASH';
    return $class->$orig(@_);
};

=head1 NAME

AI::MXNet::Perplexity

=cut

=head1 DESCRIPTION

Calculates perplexity: the exponential of the mean negative log-probability
that the predictions assign to the true labels.

Parameters
----------
ignore_label : int or undef
    Label value to exclude from the computation (commonly -1).
    If undef, all entries are counted.
axis : int (default -1)
    The axis of the prediction along which softmax was computed, i.e. the
    axis used to pick each label's probability. Defaults to the last axis.

=cut

# Accumulate the summed negative log-probability of the true labels and the
# number of counted entries over all (label, pred) pairs.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    my ($loss, $num) = (0, 0);
    zip(sub {
        my ($label, $pred) = @_;
        my $label_shape = $label->shape;
        my $pred_shape = $pred->shape;
        # One label per row of scores (the last pred dim holds the scores).
        assert(
            (product(@{ $label_shape }) == product(@{ $pred_shape })/$pred_shape->[-1]),
            "shape mismatch: (@$label_shape) vs. (@$pred_shape)"
        );
        $label = $label->as_in_context($pred->context)->reshape([$label->size]);
        # Probability the prediction assigns to each true label.
        $pred = AI::MXNet::NDArray->pick($pred, $label->astype('int32'), { axis => $self->axis });
        if(defined $self->ignore_label)
        {
            # Mask: 1 where the label is ignored, 0 elsewhere. Cast it to the
            # prediction's dtype so the mask arithmetic below stays in one
            # dtype even when labels are stored as a different type.
            my $ignore = ($label == $self->ignore_label)->astype($pred->dtype);
            $num -= $ignore->sum->asscalar;
            # Ignored entries become probability 1, i.e. contribute log(1) = 0.
            $pred = $pred*(1-$ignore) + $ignore;
        }
        # Clamp at 1e-10 so log() never sees 0.
        $loss -= $pred->maximum(1e-10)->log->sum->asscalar;
        $num += $pred->size;
    }, $labels, $preds);
    $self->sum_metric($self->sum_metric + $loss);
    $self->num_inst($self->num_inst + $num);
}

# Returns (name, perplexity). When nothing has been counted yet, or every
# entry was ignored, num_inst is 0. In that case it returns 'nan' instead of
# dying on a division by zero.
method get()
{
    return ($self->name, 'nan') if $self->num_inst == 0;
    return ($self->name, exp($self->sum_metric / $self->num_inst));
}
####################
# REGRESSION METRICS
####################
# Calculate Mean Absolute Error loss
package AI::MXNet::MAE;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'mae');

# For every (label, pred) pair, add the mean absolute difference to
# sum_metric and count one instance per pair.
# 1-D inputs are reshaped to PDL dims (1, N), i.e. a single column in MXNet's
# (N, 1) order. This is applied to the prediction as well as the label:
# reshaping only the label would make ($label - $pred) broadcast a (1, N)
# piddle against an (N) piddle into an N x N matrix of all pairwise
# differences, producing a wrong MAE.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl;
        $pred = $pred->aspdl;
        if($label->ndims == 1)
        {
            $label = $label->reshape(1, $label->shape->at(0));
        }
        if($pred->ndims == 1)
        {
            $pred = $pred->reshape(1, $pred->shape->at(0));
        }
        $self->sum_metric($self->sum_metric + ($label - $pred)->abs->avg);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
# Calculate Mean Squared Error loss
package AI::MXNet::MSE;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'mse');
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
AI::MXNet::Metric::check_label_shapes($labels, $preds);
zip(sub {
my ($label, $pred) = @_;
$label = $label->aspdl;
$pred = $pred->aspdl;
if($label->ndims == 1)
( run in 0.737 second using v1.01-cache-2.11-cpan-df04353d9ac )