Algorithm-DecisionTree

 view release on metacpan or  search on metacpan

lib/Algorithm/RegressionTree.pm  view on Meta::CPAN

    }
    if (! defined $value_for_feature) {
        my $leaf_node_prediction = $node->node_prediction_from_features_and_values($feature_and_values); 
        $answer->{'prediction'} = $leaf_node_prediction;
        push @{$answer->{'solution_path'}}, $node->get_serial_num();
        return $answer;
    }
    foreach my $child (@children) {
        my @branch_features_and_values = @{$child->get_branch_features_and_values_or_thresholds()};
        my $last_feature_and_value_on_branch = $branch_features_and_values[-1]; 
        my $pattern1 = '(.+)<(.+)';
        my $pattern2 = '(.+)>(.+)';
        if ($last_feature_and_value_on_branch =~ /$pattern1/) {
            my ($feature,$threshold) = ($1, $2);
            if ($value_for_feature <= $threshold) {
                $path_found = 1;
                $answer = $self->recursive_descent_for_prediction($child, $feature_and_values, $answer);
                push @{$answer->{'solution_path'}}, $node->get_serial_num();
                last;
            }
        }
        if ($last_feature_and_value_on_branch =~ /$pattern2/) {
            my ($feature,$threshold) = ($1, $2);
            if ($value_for_feature > $threshold) {
                $path_found = 1;
                $answer = $self->recursive_descent_for_prediction($child, $feature_and_values, $answer);
                push @{$answer->{'solution_path'}}, $node->get_serial_num();
                last;
            }
        }
    }
    return $answer if $path_found;
}

#--------------------------------------  Utility Methods   ----------------------------------------

##  Verifies that every entry of a test sample is written as "feature = value" and
##  that each feature named there is one of the names in $self->{_feature_names}.
##
##  Argument:  a reference to an array of "feature = value" strings.
##  Returns:   1 when every feature name is known (also for an empty sample), 0 as
##             soon as an unknown feature name is met.
##  Dies:      when an entry does not have the "feature = value" form.
sub _check_names_used {
    my $self = shift;
    my $features_and_values_test_data = shift;
    my $pattern = '(\S+)\s*=\s*(\S+)';
    foreach my $feature_and_value (@{$features_and_values_test_data}) {
        # Take the captures from the match's own return list.  Reading $1/$2 after
        # an unchecked match would silently reuse the captures of an earlier
        # successful match, so a malformed entry would never be reported.
        my ($feature, $value) = $feature_and_value =~ /$pattern/;
        die "Your test data has formatting error" unless defined($feature) && defined($value);
        # Feature names are compared as strings.
        return 0 unless grep { $_ eq $feature } @{$self->{_feature_names}};
    }
    return 1;
}

##  Plots the training samples together with the stored regression outputs, both in
##  an interactive gnuplot window and in the hardcopy file "regression_plots.png".
##
##  The number of columns of $self->{_XMatrix} selects the kind of display:
##    2 columns -- 'plot' of column 0 of _XMatrix against column 0 of _YVector,
##                 followed by the curves stored in $self->{_output_for_plots}
##    3 columns -- 'splot' of each _XMatrix row minus its last entry against column 0
##                 of _YVector, followed by the point sets stored in
##                 $self->{_output_for_surface_plots}
##  Any other column count is fatal.
##
##  The data is first written to the scratch file "__temp_<basename of the training
##  datafile>" in the current directory as gnuplot datasets separated by two blank
##  lines: dataset 0 holds the training samples and datasets 1..N hold the stored
##  plots in ascending numeric key order.
##
##  Dies when the column count is unsupported or the scratch file cannot be written.
sub display_all_plots {
    my $self = shift;
    my $ncols = $self->{_XMatrix}->cols;
    unlink "regression_plots.png" if -e "regression_plots.png";

    # Reject unsupported shapes before the scratch file is created, so that a
    # failed call leaves neither an empty file nor an open handle behind.
    die "no visual displays for regression from more then 2 predictor vars"
        unless $ncols == 2 || $ncols == 3;
    my $is_surface = ($ncols == 3);

    my $master_datafile = $self->{_training_datafile};
    my $temp_file = "__temp_" . basename($master_datafile);
    unlink $temp_file if -e $temp_file;
    open my $out, '>', $temp_file
           or die "Unable to open a temp file in this directory: $!\n";

    # Dataset 0: the training samples themselves.
    my @dependent_val_vals = $self->{_YVector}->col(0)->as_list;
    my $last_row = $self->{_XMatrix}->rows - 1;
    if ($is_surface) {
        foreach my $i (0 .. $last_row) {
            my @onerow = $self->{_XMatrix}->row($i)->as_list;
            # NOTE(review): the last matrix column is left out of the plot data;
            # presumably it is a constant (bias) column -- confirm.
            pop @onerow;
            print {$out} "@onerow $dependent_val_vals[$i]\n";
        }
    } else {
        my @predictor_entries = $self->{_XMatrix}->col(0)->as_list;
        foreach my $i (0 .. $last_row) {
            print {$out} "$predictor_entries[$i] $dependent_val_vals[$i]\n";
        }
    }
    print {$out} "\n\n";

    # Datasets 1..N: one per stored plot.  Each stored plot is a pair
    # [ predictor values, predicted values ].
    my $plots = $is_surface ? $self->{_output_for_surface_plots}
                            : $self->{_output_for_plots};
    my @plot_ids = sort {$a <=> $b} keys %{$plots};
    foreach my $plot (@plot_ids) {
        my ($predictors, $predictions) = @{$plots->{$plot}}[0, 1];
        # The 2D data is counted by its predictor list and the 3D data by its
        # prediction list, exactly as each display has always been driven.
        my $last = $is_surface ? $#{$predictions} : $#{$predictors};
        foreach my $i (0 .. $last) {
            print {$out} "$predictors->[$i] $predictions->[$i]\n";
        }
        print {$out} "\n\n";
    }
    # Buffered write errors only surface when the handle is closed.
    close $out or die "Unable to finish writing the temp file $temp_file: $!\n";

    my $gplot = Graphics::GnuplotIF->new( persist => 1 );
    my $hardcopy_plot = Graphics::GnuplotIF->new();
    $hardcopy_plot->gnuplot_cmd('set terminal png', "set output \"regression_plots.png\"");
    $gplot->gnuplot_cmd( "set noclip" );
    $gplot->gnuplot_cmd( "set pointsize 2" );

    # Drawing style per dataset index; every index past 2 reuses the last entry.
    my $using = $is_surface ? '1:2:3' : '1:2';
    my @styles = $is_surface
        ? ( 'notitle with points lt -1 pt 1',
            'title "linear regression" with points lt 1 pt 2',
            'title "tree regression" with points lt 3 pt 3',
            'notitle with points lt 3 pt 3' )
        : ( 'notitle with points lt -1 pt 1',
            'title "linear regression" with lines lt 1 lw 4',
            'title "tree regression" with lines lt 3 lw 4',
            'notitle with lines lt 3 lw 4' );
    my @dataset_specs;
    foreach my $i (0 .. scalar(@plot_ids)) {
        my $style = $styles[ $i < $#styles ? $i : $#styles ];
        # NOTE(review): the file name is placed inside double quotes of a gnuplot
        # command; a name containing a double quote would break the command.
        push @dataset_specs, "\"$temp_file\" index $i using $using $style";
    }
    my $arg_string = join ", ", @dataset_specs;

    my $plot_command = $is_surface ? 'splot' : 'plot';
    $hardcopy_plot->gnuplot_cmd( "$plot_command $arg_string" );
    $gplot->gnuplot_cmd( "$plot_command $arg_string" );
    $gplot->gnuplot_pause(-1);
}

##  Destructor: removes every file in the current working directory whose name
##  starts with "__temp_".  Note that the glob is not limited to files written by
##  this particular object.
sub DESTROY {
    # A destructor can run at points the caller does not control, and the file
    # operations below may change the global status variables.  Localizing them
    # keeps the caller's error state intact.
    local ($., $@, $!, $^E, $?);
    unlink glob "__temp_*";
}

############################################## Utility Routines ##########################################
# Counts how many elements of a list are string-equal to a given element.
# Arguments: the element to look for, followed by the list to search.
# Returns the number of matching elements, which is 0 (false) when the element
# does not occur in the list.
sub contained_in {
    my ($ele, @array) = @_;
    return scalar grep { $ele eq $_ } @array;
}

# Finds the numerically smallest and largest entries of an array.
# Argument: a reference to an array of numbers.
# Returns the two-element list (minimum, maximum); both entries are undef when
# the array is empty.
sub minmax {
    my $arr = shift;
    my @values = @{$arr};
    my ($smallest, $largest);
    for my $value (@values) {
        $smallest = $value if !defined($smallest) || $value < $smallest;
        $largest  = $value if !defined($largest)  || $value > $largest;
    }
    return ($smallest, $largest);
}

# Extracts the part of a name that follows its first underscore, e.g. "12" from
# "sample_12".
# Argument: the name string.
# Returns the text after the first underscore, or undef when the name contains no
# underscore followed by at least one character.
sub sample_index {
    my $arg = shift;
    # Use the match's own return list: $1 after a failed match would still hold
    # the capture of some earlier, unrelated successful match.
    my ($index) = $arg =~ /_(.+)$/;
    return $index;
}    



( run in 0.471 second using v1.01-cache-2.11-cpan-e1769b4cff6 )