AI-MXNet
view release on metacpan or search on metacpan
lib/AI/MXNet/Metric.pm view on Meta::CPAN
$self->sum_metric($self->sum_metric + $f1_score);
$self->num_inst($self->num_inst + 1);
}, $labels, $preds);
}
package AI::MXNet::Perplexity;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name'        => (default => 'Perplexity');
has 'ignore_label' => (is => 'ro', isa => 'Maybe[Int]');
has 'axis'         => (is => 'ro', isa => 'Int', default => -1);

# Support the positional shorthand AI::MXNet::Perplexity->new($ignore_label)
# while still accepting a key/value list or a single hashref of arguments.
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(ignore_label => $_[0])
        if @_ == 1 and ref($_[0]) ne 'HASH';
    return $class->$orig(@_);
};

=head1 NAME

AI::MXNet::Perplexity

=cut

=head1 DESCRIPTION

Calculate perplexity.

Parameters
----------
ignore_label : int or undef
Index of an invalid label to ignore when counting (usually -1).
All entries are included if undef.
axis : int (default -1)
The axis of the prediction that was used to compute the softmax.
By default the last axis is used.

=cut

# Accumulates the summed negative log-probability of the true labels into
# sum_metric and the number of counted (non-ignored) labels into num_inst.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    my ($loss, $num) = (0, 0);
    zip(sub {
        my ($label, $pred) = @_;
        my $label_shape = $label->shape;
        my $pred_shape  = $pred->shape;
        # pick() below selects along $self->axis, so the label must have one
        # entry per slice along THAT axis -- not necessarily the last one.
        my $axis_size = $pred_shape->[$self->axis];
        assert(
            ($axis_size and product(@{ $label_shape }) == product(@{ $pred_shape })/$axis_size),
            "shape mismatch: (@$label_shape) vs. (@$pred_shape)"
        );
        $label = $label->as_in_context($pred->context)->reshape([$label->size]);
        $pred  = AI::MXNet::NDArray->pick($pred, $label->astype('int32'), { axis => $self->axis });
        if(defined $self->ignore_label)
        {
            my $ignore = ($label == $self->ignore_label);
            # Ignored positions are excluded from the count and their
            # probability forced to 1 so that log(1) = 0 adds no loss.
            $num -= $ignore->sum->asscalar;
            $pred = $pred*(1-$ignore) + $ignore;
        }
        # Clamp at 1e-10 so a zero probability does not produce log(0).
        $loss -= $pred->maximum(1e-10)->log->sum->asscalar;
        $num  += $pred->size;
    }, $labels, $preds);
    $self->sum_metric($self->sum_metric + $loss);
    $self->num_inst($self->num_inst + $num);
}

# Returns (name, perplexity), where perplexity = exp(mean negative log-prob).
method get()
{
    # Nothing counted yet (or every label ignored): avoid dying with
    # "Illegal division by zero". NOTE(review): 'nan' is assumed to match the
    # convention of AI::MXNet::EvalMetric::get -- confirm.
    return ($self->name, 'nan') if $self->num_inst == 0;
    return ($self->name, exp($self->sum_metric / $self->num_inst));
}
####################
# REGRESSION METRICS
####################
# Calculate Mean Absolute Error loss
package AI::MXNet::MAE;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'mae');

# Adds the mean absolute error of each (label, pred) pair to sum_metric and
# counts one instance per pair.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl;
        $pred  = $pred->aspdl;
        # Turn 1-D vectors into single-column matrices (PDL dims (1, N)).
        # Both sides must be reshaped: a (1, N) label minus a 1-D (N) pred
        # would broadcast to an N x N matrix and yield a wrong error value.
        $label = $label->reshape(1, $label->shape->at(0)) if $label->ndims == 1;
        $pred  = $pred->reshape(1, $pred->shape->at(0))   if $pred->ndims == 1;
        $self->sum_metric($self->sum_metric + ($label - $pred)->abs->avg);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
# Calculate Mean Squared Error loss
package AI::MXNet::MSE;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'mse');

# Adds the mean squared error of each (label, pred) pair to sum_metric and
# counts one instance per pair.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl;
        $pred  = $pred->aspdl;
        # Turn 1-D vectors into single-column matrices (PDL dims (1, N)).
        # Both sides must be reshaped: a (1, N) label minus a 1-D (N) pred
        # would broadcast to an N x N matrix and yield a wrong error value.
        $label = $label->reshape(1, $label->shape->at(0)) if $label->ndims == 1;
        $pred  = $pred->reshape(1, $pred->shape->at(0))   if $pred->ndims == 1;
        $self->sum_metric($self->sum_metric + (($label - $pred)**2)->avg);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
# Calculate Root Mean Squared Error loss
package AI::MXNet::RMSE;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'rmse');

# Adds the root mean squared error of each (label, pred) pair to sum_metric
# and counts one instance per pair.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl;
        $pred  = $pred->aspdl;
        # Turn 1-D vectors into single-column matrices (PDL dims (1, N)).
        # Both sides must be reshaped: a (1, N) label minus a 1-D (N) pred
        # would broadcast to an N x N matrix and yield a wrong error value.
        $label = $label->reshape(1, $label->shape->at(0)) if $label->ndims == 1;
        $pred  = $pred->reshape(1, $pred->shape->at(0))   if $pred->ndims == 1;
        $self->sum_metric($self->sum_metric + sqrt((($label - $pred)**2)->avg));
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
# Calculate Cross Entropy loss
package AI::MXNet::CrossEntropy;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'cross-entropy');
has 'eps'   => (is => 'ro', isa => 'Num', default => 1e-8);

# Support the positional shorthand AI::MXNet::CrossEntropy->new($eps)
# while still accepting a key/value list or a single hashref of arguments.
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(eps => $_[0])
        if @_ == 1 and ref($_[0]) ne 'HASH';
    return $class->$orig(@_);
};

# Adds -log(p + eps) summed over every sample to sum_metric, where p is the
# prediction value selected by the sample's label, and counts one instance
# per label.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl->flat;
        $pred  = $pred->aspdl;
        my $num_labels = $label->shape->at(0);
        # The last PDL dimension of $pred is what the error message calls the
        # "first dimension of pred" (PDL lists dimensions in reverse order).
        my $batch_size = $pred->shape->at(-1);
        confess(
            "Size of label $num_labels and first dimension of pred $batch_size do not match"
        ) unless $num_labels == $batch_size;
        # PDL index() threads over the remaining dimension: element i is the
        # entry of $pred at ($label(i), i).
        my $prob = $pred->index($label);
        $self->sum_metric($self->sum_metric + (-($prob + $self->eps)->log)->sum);
        $self->num_inst($self->num_inst + $num_labels);
    }, $labels, $preds);
}
package AI::MXNet::PearsonCorrelation;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name' => (default => 'pearson-correlation');

=head1 NAME

AI::MXNet::PearsonCorrelation

=cut

=head1 DESCRIPTION

Computes Pearson correlation.

Parameters
----------
name : str
Name of this metric instance for display.

Examples
--------
>>> $predicts = [mx->nd->array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> $labels = [mx->nd->array([[1, 0], [0, 1], [0, 1]])]
>>> $pr = mx->metric->PearsonCorrelation()
>>> $pr->update($labels, $predicts)
>>> print $pr->get()
('pearson-correlation', '0.421637061887229')

=cut

# For each (label, pred) pair, flattens both arrays, adds their Pearson
# correlation coefficient to sum_metric and counts one instance.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        AI::MXNet::Metric::check_label_shapes($label, $pred);
        $label = $label->aspdl->flat;
        $pred  = $pred->aspdl->flat;
        # PDL stats() returns (mean, prms, median, min, max, adev, rms);
        # element 6 is the RMS deviation normalised by N, matching the
        # divide-by-N covariance below.
        my ($label_mean, $label_stdv) = ($label->stats)[0, 6];
        my ($pred_mean,  $pred_stdv)  = ($pred->stats)[0, 6];
        my $covariance = (($label - $label_mean) * ($pred - $pred_mean))->sum / $label->nelem;
        # NOTE(review): a constant label or pred vector makes this product 0.
        my $stdv_product = ($label_stdv * $pred_stdv)->at(0);
        $self->sum_metric($self->sum_metric + $covariance / $stdv_product);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}
=head1 DESCRIPTION
Custom evaluation metric that takes a sub ref.
Parameters
----------
eval_function : subref
( run in 0.967 second using v1.01-cache-2.11-cpan-39bf76dae61 )