AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Metric.pm  view on Meta::CPAN

    ) unless $pred_shape == $label_shape;
}

=head1 DESCRIPTION

    Base class of all evaluation metrics.
=cut

package AI::MXNet::EvalMetric;
use Mouse;

# Stringification: the literal prefix "EvalMetric: " followed by a
# Data::Dumper rendering of the hash ref returned by get_name_value().
use overload
    '""' => sub {
        my ($self) = @_;
        my $dumper = Data::Dumper->new([$self->get_name_value()]);
        $dumper->Purity(1)->Deepcopy(1)->Terse(1);
        return "EvalMetric: " . $dumper->Dump;
    },
    fallback => 1;

# Display name of the metric.
has 'name' => (
    is  => 'rw',
    isa => 'Str',
);
# When defined, reset() keeps one accumulator slot per index 0..num-1;
# when undefined, the accumulators are plain scalars.
has 'num' => (
    is  => 'rw',
    isa => 'Int',
);
# Instance counter(s): a scalar, or an array ref with 'num' entries.
has 'num_inst' => (
    is  => 'rw',
    isa => 'Maybe[Int|ArrayRef[Int]]',
);
# Running metric total(s): a scalar, or an array ref with 'num' entries.
has 'sum_metric' => (
    is  => 'rw',
    isa => 'Maybe[Num|ArrayRef[Num]]',
);
# Constructor hook: puts the accumulators into their initial state by
# delegating to reset().
sub BUILD
{
    my ($self) = @_;
    $self->reset;
}

# Abstract hook for feeding a batch of labels and predictions into the
# metric. This base implementation always confesses 'NotImplemented';
# the concrete metric classes in this file provide their own update().
method update($label, $pred)
{
    confess('NotImplemented');
}

# Zero the accumulators.
# With 'num' defined, num_inst and sum_metric each become an array ref of
# 'num' zeros; otherwise both become the scalar 0.
method reset()
{
    my $slots = $self->num;
    if(defined $slots)
    {
        $self->num_inst([(0) x $slots]);
        $self->sum_metric([(0) x $slots]);
    }
    else
    {
        $self->num_inst(0);
        $self->sum_metric(0);
    }
}

# Report the current value of the metric.
#
# Without 'num': returns the list (name, value), where value is
# sum_metric / num_inst, or the string 'nan' while num_inst is 0.
# With 'num': returns (array ref of names, array ref of values); names are
# "<name>_<index>" and each value is the per-slot quotient, or 'nan' for a
# slot whose instance count is 0.
method get()
{
    my $total = $self->sum_metric;
    my $count = $self->num_inst;
    if(defined $self->num)
    {
        my @names  = map { sprintf('%s_%d', $self->name, $_) } 0..$self->num-1;
        my @values = map {
            $count->[$_] != 0 ? $total->[$_]/$count->[$_] : 'nan'
        } 0..$#$total;
        return (\@names, \@values);
    }
    return ($self->name, $count == 0 ? 'nan' : $total / $count);
}

# Return a hash ref mapping metric name(s) to value(s), built from get().
# Scalar results from get() are treated as one-element lists.
method get_name_value()
{
    my ($names, $values) = $self->get;
    my @keys = ref $names  ? @$names  : ($names);
    my @vals = ref $values ? @$values : ($values);
    my %value_of;
    @value_of{ @keys } = @vals;
    return \%value_of;
}

package AI::MXNet::CompositeEvalMetric;
use Mouse;

extends 'AI::MXNet::EvalMetric';

# Child metrics in the order they were added; starts out empty.
has 'metrics' => (
    is      => 'rw',
    isa     => 'ArrayRef[AI::MXNet::EvalMetric]',
    default => sub { [] },
);
# The inherited 'name' attribute defaults to 'composite' for this class.
has '+name' => (
    default => 'composite',
);

# Append a child metric to the end of the 'metrics' list.
method add(AI::MXNet::EvalMetric $metric)
{
    my $children = $self->metrics;
    push @$children, $metric;
}

# Get a child metric.
#
# $index : position of the child metric in 'metrics'. Negative values count
#          from the end of the list, as with any Perl array subscript.
# Returns the metric stored at that position.
# Confesses when $index does not address an existing child, including
# negative indices that reach before the start of the list (these used to
# return undef silently).
method get_metric(int $index)
{
    my $count = @{ $self->metrics };
    my $max   = $count - 1;
    confess("Metric index $index is out of range 0 and $max")
        if $index > $max or $index < -$count;
    return $self->metrics->[$index];
}

# Forward the same labels and predictions to every child metric, in the
# order the children are stored.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    my $children = $self->metrics;
    foreach my $child (@$children)
    {
        $child->update($labels, $preds);
    }
}

method reset()
{
    for my $metric (@{ $self->metrics })

lib/AI/MXNet/Metric.pm  view on Meta::CPAN


# Calculate Cross Entropy loss
package AI::MXNet::CrossEntropy;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::EvalMetric';
has '+name'   => (default => 'cross-entropy');
# Small constant added to each selected probability before taking its log.
has 'eps'     => (is => 'ro', isa => 'Num', default => 1e-8);

# Constructor argument handling:
#   ->new($eps)            a lone plain scalar is shorthand for eps => $eps
#   ->new({ eps => ... })  a lone hash ref is handed on unchanged, so the
#                          usual hash-ref constructor form keeps working
#                          instead of being mistaken for an eps value
#   ->new(key => value...) handed on unchanged
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(eps => $_[0]) if @_ == 1 and not ref $_[0];
    return $class->$orig(@_);
};

# Accumulate the cross-entropy loss over a batch.
#
# $labels : array ref of AI::MXNet::NDArray; each is flattened and used as
#           indices into the matching prediction.
# $preds  : array ref of AI::MXNet::NDArray with the predicted probabilities.
#
# For every (label, pred) pair the probability selected by each label is
# looked up, -log(prob + eps) is summed into sum_metric, and num_inst grows
# by the number of labels.
# Confesses when the number of labels differs from the last PDL dimension
# of the prediction.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl->flat;
        $pred =  $pred->aspdl;
        my $label_shape = $label->shape->at(0);
        # NOTE(review): presumably the last PDL dimension corresponds to the
        # first NDArray dimension (reversed dimension order) - confirm.
        my $pred_shape  = $pred->shape->at(-1);
        # The message is built from two concatenated pieces; it used to be a
        # single literal containing a raw line break and a stray '.'.
        confess(
            "Size of label $label_shape and "
            ."first dimension of pred $pred_shape do not match"
        ) unless $label_shape == $pred_shape;
        my $prob = $pred->index($label);
        $self->sum_metric($self->sum_metric + (-($prob + $self->eps)->log)->sum);
        $self->num_inst($self->num_inst + $label_shape);
    }, $labels, $preds);
}

package AI::MXNet::PearsonCorrelation;
use Mouse;
use AI::MXNet::Base;

extends 'AI::MXNet::EvalMetric';

# The inherited 'name' attribute defaults to 'pearson-correlation'.
has '+name' => (
    default => 'pearson-correlation',
);

=head1 NAME

    AI::MXNet::PearsonCorrelation
=cut

=head1 DESCRIPTION

    Computes Pearson correlation.

    Parameters
    ----------
    name : str
        Name of this metric instance for display.

    Examples
    --------
    >>> $predicts = [mx->nd->array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
    >>> $labels   = [mx->nd->array([[1, 0], [0, 1], [0, 1]])]
    >>> $pr = mx->metric->PearsonCorrelation()
    >>> $pr->update($labels, $predicts)
    >>> print $pr->get()
    ('pearson-correlation', '0.421637061887229')
=cut

# Accumulate one correlation value per (label, pred) pair.
#
# Both arrays are flattened, then the mean product of their deviations
# from the mean is divided by the product of the two spread values taken
# from PDL's stats() (list elements 0 and 6 - presumably mean and RMS
# deviation; confirm against the PDL documentation). The quotient is added
# to sum_metric and num_inst grows by one per pair.
# NOTE(review): a spread of zero for either input looks like it leads to a
# division by zero here - confirm.
method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds);
    zip(sub {
        my ($label_nd, $pred_nd) = @_;
        AI::MXNet::Metric::check_label_shapes($label_nd, $pred_nd);
        my $label = $label_nd->aspdl->flat;
        my $pred  = $pred_nd->aspdl->flat;
        my ($label_mean, $label_stdv) = ($label->stats)[0, 6];
        my ($pred_mean,  $pred_stdv)  = ($pred->stats)[0, 6];
        my $covariance = (($label - $label_mean) * ($pred - $pred_mean))->sum / $label->nelem;
        my $scale      = ($label_stdv * $pred_stdv)->at(0);
        $self->sum_metric($self->sum_metric + $covariance / $scale);
        $self->num_inst($self->num_inst + 1);
    }, $labels, $preds);
}

=head1 DESCRIPTION

    Custom evaluation metric that takes a sub ref.

    Parameters
    ----------
    eval_function : subref
        Customized evaluation function.
    name : str, optional
        The name of the metric
    allow_extra_outputs : bool
        If true, the prediction outputs can have extra outputs.
        This is useful in RNN, where the states are also produced
        in outputs for forwarding.
=cut


package AI::MXNet::CustomMetric;
use Mouse;
use AI::MXNet::Base;

extends 'AI::MXNet::EvalMetric';

# Code ref that update() calls with each (label, pred) pair.
has 'eval_function' => (
    is  => 'ro',
    isa => 'CodeRef',
);
# When true, update() skips the label/prediction shape check.
has 'allow_extra_outputs' => (
    is      => 'ro',
    isa     => 'Int',
    default => 0,
);

method update(ArrayRef[AI::MXNet::NDArray] $labels, ArrayRef[AI::MXNet::NDArray] $preds)
{
    AI::MXNet::Metric::check_label_shapes($labels, $preds)
        unless $self->allow_extra_outputs;
    zip(sub {
        my ($label, $pred) = @_;
        $label = $label->aspdl;
        $pred =  $pred->aspdl;
        my $value = $self->eval_function->($label, $pred);
        my $sum_metric = ref $value ? $value->[0] : $value;
        my $num_inst   = ref $value ? $value->[1] : 1;
        $self->sum_metric($self->sum_metric + $sum_metric);
        $self->num_inst($self->num_inst + $num_inst);
    }, $labels, $preds);



( run in 1.207 second using v1.01-cache-2.11-cpan-39bf76dae61 )