Algorithm-DecisionTree
view release on metacpan or search on metacpan
lib/Algorithm/RegressionTree.pm view on Meta::CPAN
## Descends the regression tree starting at $node.  At every interior node the
## value that the test sample supplies for the feature tested at that node is
## compared against the threshold on each child's branch, and the descent
## continues into the first child whose test is satisfied.
##
## Arguments:
##   $node                - node object at which the descent starts
##   $feature_and_values  - ref to an array of 'feature=value' strings
##   $answer              - hash ref in which the result is accumulated
##
## Returns $answer.  Its 'prediction' entry is set from the node at which the
## descent stopped, and the serial numbers of the visited nodes are pushed onto
## its 'solution_path' entry, deepest node first.
##
## The descent stops at a node (and uses that node's prediction) when the node
## has no children, when the test sample supplies no value for the feature
## tested at the node, or when no child branch applies to the supplied value.
sub recursive_descent_for_prediction {
    my $self = shift;
    my $node = shift;
    my $feature_and_values = shift;
    my $answer = shift;
    my @children = @{$node->get_children()};

    # Records this node as the end of the descent.
    my $stop_at_this_node = sub {
        $answer->{'prediction'} =
            $node->node_prediction_from_features_and_values($feature_and_values);
        push @{$answer->{'solution_path'}}, $node->get_serial_num();
        return $answer;
    };

    return $stop_at_this_node->() if @children == 0;

    my $feature_tested_at_node = $node->get_feature();
    print "\nFeature tested at node for prediction: $feature_tested_at_node\n" if $self->{_debug3};

    my $value_for_feature;
    foreach my $feature_and_value (@$feature_and_values) {
        # Skip entries that do not parse.  After a failed match $1 and $2 keep
        # the captures of an earlier successful match (inside a recursive call,
        # the caller's branch-threshold match), so they must not be read
        # without testing the match result first.
        my ($feature, $value) = $feature_and_value =~ /(\S+)\s*=\s*(\S+)/
            or next;
        # No 'last' here: if a feature is listed twice, the final entry wins.
        $value_for_feature = $value if $feature eq $feature_tested_at_node;
    }
    return $stop_at_this_node->() if !defined $value_for_feature;

    foreach my $child (@children) {
        my @branch_features_and_values = @{$child->get_branch_features_and_values_or_thresholds()};
        my $last_feature_and_value_on_branch = $branch_features_and_values[-1];
        my $branch_applies = 0;
        # A 'feature<threshold' branch is taken for values up to and including
        # the threshold; a 'feature>threshold' branch for values above it.
        if (my ($threshold) = $last_feature_and_value_on_branch =~ /.+<(.+)/) {
            $branch_applies = 1 if $value_for_feature <= $threshold;
        }
        if (!$branch_applies) {
            if (my ($threshold) = $last_feature_and_value_on_branch =~ /.+>(.+)/) {
                $branch_applies = 1 if $value_for_feature > $threshold;
            }
        }
        next unless $branch_applies;
        $answer = $self->recursive_descent_for_prediction($child, $feature_and_values, $answer);
        push @{$answer->{'solution_path'}}, $node->get_serial_num();
        return $answer;
    }

    # No child branch applies to the supplied value.  Fall back to this node's
    # own prediction instead of falling off the end of the sub, which used to
    # hand a false non-reference back to the caller.
    return $stop_at_this_node->();
}
#-------------------------------------- Utility Methods ----------------------------------------
## Verifies that every entry of a test sample is a well-formed 'feature=value'
## string whose feature name is one of the names in $self->{_feature_names}.
##
## Argument:
##   $features_and_values_test_data - ref to an array of 'feature=value' strings
##
## Returns 1 when every feature name is known, and 0 as soon as an unknown
## feature name is encountered.  Dies when an entry cannot be parsed as
## 'feature=value'.
sub _check_names_used {
    my $self = shift;
    my $features_and_values_test_data = shift;
    foreach my $feature_and_value (@$features_and_values_test_data) {
        # Take the captures from the match in list context.  Reading $1 and $2
        # after an unchecked match would pick up the captures of the previous
        # entry whenever this one fails to match, so a malformed entry that
        # follows a valid one would never be reported.
        my ($feature, $value) = $feature_and_value =~ /(\S+)\s*=\s*(\S+)/;
        die "Your test data has formatting error"
            unless defined($feature) && defined($value);
        return 0 unless contained_in($feature, @{$self->{_feature_names}});
    }
    return 1;
}
sub display_all_plots {
my $self = shift;
my $ncols = $self->{_XMatrix}->cols;
unlink "regression_plots.png" if -e "regression_plots.png";
my $master_datafile = $self->{_training_datafile};
my $filename = basename($master_datafile);
my $temp_file = "__temp_" . $filename;
unlink $temp_file if -e $temp_file;
open OUTPUT, ">$temp_file"
or die "Unable to open a temp file in this directory: $!\n";
if ($ncols == 2) {
my @predictor_entries = $self->{_XMatrix}->col(0)->as_list;
my @dependent_val_vals = $self->{_YVector}->col(0)->as_list;
map {print OUTPUT "$predictor_entries[$_] $dependent_val_vals[$_]\n"} 0 .. $self->{_XMatrix}->rows - 1;
print OUTPUT "\n\n";
foreach my $plot (sort {$a <=> $b} keys %{$self->{_output_for_plots}}) {
map {print OUTPUT "$self->{_output_for_plots}->{$plot}->[0]->[$_] $self->{_output_for_plots}->{$plot}->[1]->[$_]\n"} 0 .. @{$self->{_output_for_plots}->{$plot}->[0]} - 1;
print OUTPUT "\n\n"
}
close OUTPUT;
my $gplot = Graphics::GnuplotIF->new( persist => 1 );
my $hardcopy_plot = Graphics::GnuplotIF->new();
$hardcopy_plot->gnuplot_cmd('set terminal png', "set output \"regression_plots.png\"");
$gplot->gnuplot_cmd( "set noclip" );
$gplot->gnuplot_cmd( "set pointsize 2" );
my $arg_string = "";
foreach my $i (0 .. scalar(keys %{$self->{_output_for_plots}})) {
if ($i == 0) {
$arg_string .= "\"$temp_file\" index $i using 1:2 notitle with points lt -1 pt 1, ";
} elsif ($i == 1) {
$arg_string .= "\"$temp_file\" index $i using 1:2 title \"linear regression\" with lines lt 1 lw 4, ";
} elsif ($i == 2) {
$arg_string .= "\"$temp_file\" index $i using 1:2 title \"tree regression\" with lines lt 3 lw 4, ";
} else {
$arg_string .= "\"$temp_file\" index $i using 1:2 notitle with lines lt 3 lw 4, ";
}
}
$arg_string = $arg_string =~ /^(.*),[ ]+$/;
$arg_string = $1;
$hardcopy_plot->gnuplot_cmd( "plot $arg_string" );
$gplot->gnuplot_cmd( "plot $arg_string" );
$gplot->gnuplot_pause(-1);
} elsif ($ncols == 3) {
my @dependent_val_vals = $self->{_YVector}->col(0)->as_list;
foreach my $i (0 .. $self->{_XMatrix}->rows - 1) {
( run in 1.129 second using v1.01-cache-2.11-cpan-13bb782fe5a )