AI-MXNet

 view release on metacpan or  search on metacpan

lib/AI/MXNet/Initializer.pm  view on Meta::CPAN

# Initializer for "loc bias" parameters.
#
# Dies (with a stack trace, via confess) unless the first dimension of
# $arr is exactly 6, then applies `.=` with the six values 1,0,0,0,1,0.
# NOTE(review): the constants look like a flattened 2x3 identity affine
# transform -- confirm against the operator that consumes this bias.
# $name is accepted for interface uniformity and is not used here.
method _init_loc_bias($name, $arr)
{
    my $leading_dim = $arr->shape->[0];
    if($leading_dim != 6)
    {
        confess("assert error shape[0] == 6");
    }
    $arr .= [1.0, 0, 0, 0, 1.0, 0];
}

# Applies `.= 0` to $arr.
# NOTE(review): `.=` is presumably overloaded by the array class as an
# in-place fill/assignment -- confirm in AI::MXNet::NDArray.
# $name is not used by this method.
method _init_zero($name, $arr)
{
    $arr .= 0;
}

# Applies `.= 1` to $arr.
# NOTE(review): `.=` is presumably overloaded by the array class as an
# in-place fill/assignment -- confirm in AI::MXNet::NDArray.
# $name is not used by this method.
method _init_one($name, $arr)
{
    $arr .= 1;
}

# Default handling for "bias" parameters: applies `.= 0` to $arr.
# NOTE(review): `.=` is presumably an overloaded in-place assignment on
# the array object -- confirm in AI::MXNet::NDArray.
# $name is not used by this method.
method _init_bias($name, $arr)
{
    $arr .= 0;
}

# Default handling for "gamma" parameters: applies `.= 1` to $arr.
# NOTE(review): `.=` is presumably an overloaded in-place assignment on
# the array object -- confirm in AI::MXNet::NDArray.
# $name is not used by this method.
method _init_gamma($name, $arr)
{
    $arr .= 1;
}

# Default handling for "beta" parameters: applies `.= 0` to $arr.
# NOTE(review): `.=` is presumably an overloaded in-place assignment on
# the array object -- confirm in AI::MXNet::NDArray.
# $name is not used by this method.
method _init_beta($name, $arr)
{
    $arr .= 0;
}

# Placeholder for weight initialization: always dies via confess.
# As the message says, concrete initializer subclasses are expected to
# override this method. Neither $name nor $arr is touched here.
method _init_weight($name, $arr)
{
    confess("Virtual method, subclass must override it");
}

# Fallback for parameter names that match no known initialization
# pattern: always dies via confess with a message naming the parameter
# and listing the patterns that have a default.
# $arr is not touched.
method _init_default($name, $arr)
{
    # Every fragment except the last ends with a space so the sentences
    # do not run together once concatenated (the original produced
    # "...(0.0).Please use...").
    confess(
        "Unknown initialization pattern for $name. "
        .'Default initialization is now limited to '
        .'"weight", "bias", "gamma" (1.0), and "beta" (0.0). '
        .'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern'
    );
}

=head1 NAME

    AI::MXNet::Load - Initialize by loading pretrained parameters from a hash ref.
=cut

=head2 new

    Parameters
    ----------
    param: HashRef[AI::MXNet::NDArray]
    default_init: Initializer
        default initializer when a name is not found in the param hash ref.
    verbose: bool
        log the names when initializing.
=cut

package AI::MXNet::Load;
use Mouse;
extends 'AI::MXNet::Initializer';

# Constructor attributes of AI::MXNet::Load.
#   param        - required; hash ref mapping parameter names to NDArray objects.
#   default_init - optional initializer object; `call` consults it for
#                  names that are missing from `param`.
#   verbose      - integer flag, 0 by default; a true value makes `call`
#                  log each initialization.
has 'param' => (
    is       => 'rw',
    isa      => 'HashRef[AI::MXNet::NDArray]',
    required => 1,
);
has 'default_init' => (
    is  => 'rw',
    isa => 'AI::MXNet::Initializer',
);
has 'verbose' => (
    is      => 'rw',
    isa     => 'Int',
    default => 0,
);

# Post-construction hook: normalizes the `param` attribute.
#
# If `param` is not a reference it is handed to AI::MXNet::NDArray->load
# and the value that call returns is used instead (presumably a hash ref
# of name => NDArray -- confirm against load()).  A leading "arg:" or
# "aux:" is then stripped from every key and the resulting hash ref is
# stored back into `param`.
#
# NOTE(review): the `param` attribute is declared elsewhere in this package
# as 'HashRef[AI::MXNet::NDArray]', so a plain file name is rejected before
# BUILD runs; the load branch only becomes reachable if that constraint is
# relaxed.
sub BUILD
{
    my $self  = shift;
    my $param = $self->param;

    # The loaded data used to land in an otherwise unused lexical declared
    # with a statement modifier ("my ... unless"), so it was thrown away
    # and the loop below dereferenced the original non-reference value.
    $param = AI::MXNet::NDArray->load($param) unless ref $param;

    my %self_param;
    # Iterate over keys() rather than each() so a half-consumed hash
    # iterator left behind by other code cannot make us skip entries.
    for my $key (keys %{ $param })
    {
        (my $name = $key) =~ s/^(?:arg|aux)://;
        $self_param{ $name } = $param->{ $key };
    }
    $self->param(\%self_param);
}

# Initialize $arr for the parameter called $name.
#
# If $name is a key of `param`, the stored array is assigned to $arr via
# `.=` after checking that both have the same shape; a shape mismatch
# dies via confess.  Otherwise the work is delegated to `default_init`;
# if no default initializer was supplied this dies via confess.
# When `verbose` is true, the path taken is logged through
# AI::MXNet::Log->info.
method call(Str $name, AI::MXNet::NDArray $arr)
{
    if(exists $self->param->{ $name })
    {
        my $loaded       = $self->param->{ $name };
        my $target_shape = join(',', @{ $arr->shape });
        my $param_shape  = join(',', @{ $loaded->shape });
        confess(
            "Parameter $name cannot be initialized from loading. "
            ."Shape mismatch, target $target_shape vs loaded $param_shape"
        ) unless $target_shape eq $param_shape;
        $arr .= $loaded;
        AI::MXNet::Log->info("Initialized $name by loading") if $self->verbose;
    }
    else
    {
        confess(
            "Cannot Initialize $name. Not found in loaded param "
            ."and no default Initializer is provided."
        ) unless defined $self->default_init;
        # Fetch the initializer object first and then invoke it.  Passing
        # ($name, $arr) straight to the accessor would call it as a setter
        # and try to store $name in the attribute instead of running the
        # fallback initializer.
        $self->default_init->call($name, $arr);
        AI::MXNet::Log->info("Initialized $name by default") if $self->verbose;
    }
}

# Alias the `slice` typeglob to `call`, so invoking ->slice on an
# AI::MXNet::Load object runs the same code as ->call.
# NOTE(review): the reason for exposing the method under this second
# name is not visible here -- confirm against the base Initializer class.
*slice = *call;

=head1 NAME

    AI::MXNet::Mixed - A container for multiple initializer patterns.
=cut

=head2 new

    patterns: array ref of str
        array ref of regular expression patterns to match parameter names.
    initializers: array ref of AI::MXNet::Initializer objects.
        array ref of Initializers corresponding to the patterns.



( run in 1.055 second using v1.01-cache-2.11-cpan-39bf76dae61 )