AI-DecisionTree

 view release on metacpan or  search on metacpan

lib/AI/DecisionTree.pm  view on Meta::CPAN

		purge => 1,
		verbose => 0,
		max_depth => 0,
		@_,
		nodes => 0,
		instances => [],
		name_gen => 0,
	       }, $package;
}

# Accessor for the object's 'nodes' field, the counter that _expand_node()
# increments once for every node it builds.
sub nodes {
  my ($self) = @_;
  return $self->{nodes};
}
# Accessor for the object's 'noise_mode' field.  _expand_node() consults this
# setting: the value 'fatal' makes unsplittable, inconsistent data an error.
sub noise_mode {
  my ($self) = @_;
  return $self->{noise_mode};
}
# Accessor for the object's 'depth' field, which train() zeroes and
# _expand_node() raises to the deepest level reached while building the tree.
sub depth {
  my ($self) = @_;
  return $self->{depth};
}

# Add one training instance to the tree's instance list.
#
# Named parameters:
#   attributes => \%hash   attribute name => attribute value (required)
#   result     => $label   the instance's result; must be defined (required)
#   name       => $name    optional; defaults to the next value of an
#                          internal counter
#
# Attribute names, attribute values and result labels are converted to
# integer ids through _hlookup() before being handed to
# AI::DecisionTree::Instance->new.  Croaks when a required parameter is
# missing.  Returns whatever push() returns (the new number of instances).
sub add_instance {
  my ($self, %args) = @_;
  croak "Missing 'attributes' parameter" unless $args{attributes};
  croak "Missing 'result' parameter" unless defined $args{result};
  $args{name} = $self->{name_gen}++ unless exists $args{name};

  # Iterate with keys() rather than each(): the hash belongs to the caller,
  # and each() would resume from wherever a previous, unfinished iteration
  # over that same hash stopped, silently skipping attributes.  keys()
  # always starts from the beginning.
  my $given = $args{attributes};
  my @attributes;
  foreach my $attr (keys %$given) {
    $attributes[ _hlookup($self->{attributes}, $attr) ] = _hlookup($self->{attribute_values}{$attr}, $given->{$attr});
  }

  # Slots for attributes this instance does not define hold 0; real value
  # ids start at 1 (see _hlookup).
  $_ ||= 0 foreach @attributes;

  push @{$self->{instances}}, AI::DecisionTree::Instance->new(\@attributes, _hlookup($self->{results}, $args{result}), $args{name});
}

# Return the integer id registered for $key in a lookup table, registering
# the key first if it has not been seen before.
#
#   _hlookup($table, $key)
#
# The table is taken through $_[0] on purpose: if the caller's slot is still
# undefined (or otherwise false) it is replaced in place by a fresh hash, so
# callers may pass a not-yet-existing hash element.  Ids are handed out in
# order of first appearance, starting at 1.
sub _hlookup {
  my $table = ($_[0] ||= {});   # assigns into the caller's variable
  my $key   = $_[1];

  # The right-hand side is computed before the key is created, so the first
  # key in an empty table receives id 1.
  $table->{$key} = 1 + keys(%$table)
    unless exists $table->{$key};

  return $table->{$key};
}

# Build the reverse (id => name) tables from the forward (name => id) tables
# filled in by _hlookup().
#
#   $self->{results_reverse}                 id => result label
#   $self->{attribute_values_reverse}{$attr} id => value of attribute $attr
#
# Each reverse table is an array whose slot 0 is undef, because ids start
# at 1; the remaining slots list the names in ascending id order.
sub _create_lookup_hashes {
  my ($self) = @_;

  my $result_ids = $self->{results};
  my @labels_by_id = sort { $result_ids->{$a} <=> $result_ids->{$b} } keys %$result_ids;
  $self->{results_reverse} = [ undef, @labels_by_id ];

  for my $attr (keys %{ $self->{attribute_values} }) {
    my $value_ids = $self->{attribute_values}{$attr};
    my @values_by_id = sort { $value_ids->{$a} <=> $value_ids->{$b} } keys %$value_ids;
    $self->{attribute_values_reverse}{$attr} = [ undef, @values_by_id ];
  }
}

# Build the decision tree from the instances added so far.
#
# Named parameters:
#   max_depth => $n   optional; overrides the object's max_depth for this
#                     call only (a false value means no depth limit, see
#                     _expand_node)
#
# Croaks if there are no training instances, with a distinct message when a
# tree already exists (i.e. the data was purged by an earlier train()).
# Afterwards the tree is pruned if the 'prune' field is set and the training
# data is discarded if purge() is true.  Returns 1.
sub train {
  my ($self, %args) = @_;
  if (not @{ $self->{instances} }) {
    croak "Training data has been purged, can't re-train" if $self->{tree};
    croak "Must add training instances before calling train()";
  }
  
  $self->_create_lookup_hashes;

  # local() restores both fields when train() returns, so a per-call
  # max_depth never leaks into the object's permanent settings.
  local $self->{curr_depth} = 0;
  local $self->{max_depth} = $args{max_depth} if exists $args{max_depth};

  # Per-tree statistics start from scratch on every run.  Without resetting
  # 'nodes' here, re-training (possible when purge is off) would report the
  # node count of all trees ever built instead of the current one.
  # NOTE(review): _expand_node also accumulates $self->{prior_freqs} across
  # runs; its consumers are not visible here, so it is left untouched.
  $self->{depth} = 0;
  $self->{nodes} = 0;
  $self->{tree} = $self->_expand_node( instances => $self->{instances} );
  $self->{total_instances} = @{$self->{instances}};
  
  $self->prune_tree if $self->{prune};
  $self->do_purge if $self->purge;
  return 1;
}

# Discard the training data held by the object: the instance list and the
# forward/reverse lookup tables for attribute values and results.  Other
# fields (for example 'tree' and 'attributes') are left in place.
sub do_purge {
  my ($self) = @_;
  my @training_fields = qw(instances attribute_values attribute_values_reverse results results_reverse);
  return delete @$self{@training_fields};
}

# Take over the training data of another decision tree.
#
# Named parameters:
#   from => $tree   the tree to copy from (required)
#
# The instance list and the attribute / attribute-value / result lookup
# tables are assigned by reference, so both objects refer to the same
# underlying data afterwards.  The reverse lookup tables are then rebuilt
# for this object.  Croaks if 'from' is missing or is not an object of this
# package.
sub copy_instances {
  my ($self, %opt) = @_;

  exists $opt{from}
    or croak "Missing 'from' parameter to copy_instances()";
  my $source = $opt{from};
  UNIVERSAL::isa($source, __PACKAGE__)
    or croak "'from' parameter is not a decision tree";

  my @shared_fields = qw(instances attributes attribute_values results);
  @{$self}{@shared_fields} = @{$source}{@shared_fields};

  $self->_create_lookup_hashes;
}

# Replace the result of every stored instance.
#
#   $self->set_results(\%result_for)
#
# %result_for maps instance name => new result label and must contain a
# defined entry for every instance; otherwise the method croaks.
sub set_results {
  my ($self, $hashref) = @_;
  foreach my $instance (@{$self->{instances}}) {
    my $name = $instance->name;

    # Require a defined result, matching the check in add_instance().
    croak "No result given for instance '$name'" unless defined $hashref->{$name};

    # Go through _hlookup() so that a label not seen before is registered
    # and gets a proper id.  A plain read of $self->{results} would yield
    # undef for such a label and store an undefined result id on the
    # instance.
    $instance->set_result( _hlookup($self->{results}, $hashref->{$name}) );
  }
}

# Accessor for the object's 'instances' field (the array reference that
# add_instance() pushes onto).
sub instances {
  my ($self) = @_;
  return $self->{instances};
}

# Combined getter/setter for the 'purge' flag, which train() checks to
# decide whether to call do_purge().  With an argument the flag is set to
# that value first (even if the value is undef); the current value is
# returned either way.
sub purge {
  my ($self, @new_value) = @_;
  if (@new_value) {
    $self->{purge} = $new_value[0];
  }
  return $self->{purge};
}

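# Each node contains:
#  { split_on => $attr_name,
#    children => { $attr_value1 => $node1,
#                  $attr_value2 => $node2, ... }
#  }
# or
#  { result => $result }
#
# In addition, every node (internal or leaf) carries:
#    distribution => [ $result1 => $count1, $result2 => $count2, ... ],
#                    # flat list, most frequent result first
#    instances    => $number_of_training_instances_at_this_node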

# Recursively build the (sub)tree for a set of training instances.
#
# Named parameters:
#   instances => \@instances   the instances that reach this node
#
# Returns a hash reference describing the node.  Every node has
# 'distribution' (flat list of result => count, most frequent first) and
# 'instances' (how many instances reached it).  A leaf additionally has
# 'result'; an internal node has 'split_on' and 'children'.
#
# Side effects on the object: bumps 'nodes', raises 'depth' to the deepest
# level seen, and adds this node's result counts to 'prior_freqs'.
sub _expand_node {
  my ($self, %args) = @_;
  my $instances = $args{instances};
  print STDERR '.' if $self->{verbose};

  # Remember the deepest level reached, then go one level down for the
  # duration of this call; local() undoes the increment on return.
  if ($self->{curr_depth} > $self->{depth}) {
    $self->{depth} = $self->{curr_depth};
  }
  local $self->{curr_depth} = $self->{curr_depth} + 1;
  $self->{nodes}++;

  # Count how many instances carry each result.
  my %count_of;
  for my $instance (@$instances) {
    $count_of{ $self->_result($instance) }++;
  }
  my @by_frequency = sort { $count_of{$b} <=> $count_of{$a} } keys %count_of;
  my @distribution = map { ($_, $count_of{$_}) } @by_frequency;
  my %node = ( distribution => \@distribution, instances => scalar @$instances );

  for my $result (keys %count_of) {
    $self->{prior_freqs}{$result} += $count_of{$result};
  }

  # Only one result among these instances: this node is a leaf.
  if (@by_frequency == 1) {
    $node{result} = $self->_result($instances->[0]);
    return \%node;
  }

  # Several results are present - look for the attribute that separates
  # them best.
  my $best_attr = $self->best_attr($instances);
  my $no_split  = !defined $best_attr;

  if ($self->{noise_mode} eq 'fatal' and $no_split) {
    croak "Inconsistent data, can't build tree with noise_mode='fatal'";
  }

  # No usable attribute, or the depth limit (if any) is exceeded: make a
  # leaf that predicts the most frequent result.
  my $depth_limit = $self->{max_depth};
  if ($no_split or ($depth_limit and $self->{curr_depth} > $depth_limit)) {
    $node{result} = $by_frequency[0];
    return \%node;
  }

  $node{split_on} = $best_attr;

  # Partition the instances by their value for the chosen attribute;
  # instances without a value share the '<undef>' bucket.
  my %bucket_for;
  for my $instance (@$instances) {
    my $value = $self->_value($instance, $best_attr);
    $value = '<undef>' unless defined $value;
    push @{ $bucket_for{$value} }, $instance;
  }
  if (keys(%bucket_for) <= 1) {
    die ("Something's wrong: attribute '$best_attr' didn't split ",
         scalar @$instances, " instances into multiple buckets (@{[ keys %bucket_for ]})");
  }

  # One child subtree per attribute value.
  for my $value (keys %bucket_for) {
    $node{children}{$value} = $self->_expand_node( instances => $bucket_for{$value} );
  }

  return \%node;
}

# Choose the attribute whose split leaves the least remaining entropy.
#
#   my $attr_name = $self->best_attr(\@instances);
#
# Scores are instance-weighted entropies: 0 is a perfect split, and the
# unsplit set's own score (number of instances times its entropy) is the
# value to beat.  Returns the name of the best attribute, or undef when no
# attribute scores strictly better than not splitting at all.
sub best_attr {
  my ($self, $instances) = @_;

  my @result_ids = map { $_->result_int } @$instances;
  my $best_score = @$instances * $self->entropy(@result_ids);
  my $best_attr;

  my $attr_ids = $self->{attributes};
  foreach my $attr (keys %$attr_ids) {

    # tally() fills in:
    #   %tallies - for each attribute value, the count of each result
    #   %totals  - for each attribute value, the number of instances
    my (%totals, %tallies);
    my $undef_count = AI::DecisionTree::Instance::->tally($instances, \%tallies, \%totals, $attr_ids->{$attr});

    # Ignore attributes that none of these instances define.
    next unless keys %totals;

    my $score = 0;
    foreach my $value (keys %tallies) {
      $score += $totals{$value} * $self->entropy2( $tallies{$value}, $totals{$value} );
    }

    if ($score < $best_score) {
      $best_attr  = $attr;
      $best_score = $score;
    }
  }

  return $best_attr;
}

# Entropy, in bits, of a distribution given as pre-computed counts.
#
#   my $bits = $self->entropy2(\%counts, $total);
#
# %counts maps each category to its number of occurrences and $total is the
# sum of those counts.  The invocant is ignored.  An empty distribution
# ($total false) has entropy 0, and categories with a zero count contribute
# nothing; both cases would otherwise die with "Can't take log of 0".
sub entropy2 {
  shift;
  my ($counts, $total) = @_;

  return 0 unless $total;

  # Entropy is defined with log base 2 - we just divide by log(2) at the end to adjust.
  my $sum = 0;
  foreach my $count (values %$counts) {
    next unless $count;   # 0 * log(0) is taken as 0
    $sum += $count * log($count);
  }
  return +(log($total) - $sum/$total)/log(2);
}

# Entropy, in bits, of a list of category labels.
#
#   my $bits = $self->entropy(@labels);
#
# The invocant is ignored.  Each distinct label is one category.  An empty
# list has entropy 0 (it would otherwise die with "Can't take log of 0").
sub entropy {
  shift;

  my $total = @_;
  return 0 unless $total;

  my %count;
  $count{$_}++ foreach @_;

  # Entropy is defined with log base 2 - we just divide by log(2) at the end to adjust.
  my $sum = 0;
  $sum += $_ * log($_) foreach values %count;
  return +(log($total) - $sum/$total)/log(2);
}

sub prune_tree {
  my $self = shift;

  # We use a minimum-description-length approach.  We calculate the
  # score of each node:
  #  n = number of nodes below
  #  r = number of results (categories) in the entire tree



( run in 1.224 second using v1.01-cache-2.11-cpan-97f6503c9c8 )