AI-Categorizer

 view release on metacpan or  search on metacpan

lib/AI/Categorizer/Learner/Weka.pm  view on Meta::CPAN

package AI::Categorizer::Learner::Weka;

use strict;
use AI::Categorizer::Learner::Boolean;
use base qw(AI::Categorizer::Learner::Boolean);
use Params::Validate qw(:types);
use File::Spec;
use File::Copy;
use File::Path ();
use File::Temp ();

# Constructor parameters, validated by Params::Validate via the parent class:
#   java_path       - java executable to invoke (default: 'java')
#   java_args       - extra JVM arguments, either a string or an array ref
#   weka_path       - classpath for Weka; new() turns it into '-classpath'
#   weka_classifier - Weka classifier class name
#   weka_args       - extra classifier options, either a string or an array ref
#   tmpdir          - parent directory for the model's working directory
__PACKAGE__->valid_params(
  java_path       => { type => SCALAR,          default  => 'java' },
  java_args       => { type => SCALAR|ARRAYREF, optional => 1 },
  weka_path       => { type => SCALAR,          optional => 1 },
  weka_classifier => { type => SCALAR,          default  => 'weka.classifiers.NaiveBayes' },
  weka_args       => { type => SCALAR|ARRAYREF, optional => 1 },
  tmpdir          => { type => SCALAR,          default  => File::Spec->tmpdir },
);

# 'features' objects are created lazily (delayed); create_model() asks for
# one through create_delayed_object('features') to build the dummy ARFF file.
__PACKAGE__->contained_objects(
  features => { delayed => 1, class => 'AI::Categorizer::FeatureVector' },
);

# Constructor.  Lets the parent class validate the parameters, then:
#   - normalizes java_args and weka_args into array refs owned by this
#     object (undef -> [], a plain string -> [string], an array -> a copy);
#   - folds weka_path, if given, into java_args as '-classpath <path>' and
#     removes it from the object.
sub new {
  my $class = shift;
  my $self = $class->SUPER::new(@_);

  for my $key ('java_args', 'weka_args') {
    my $value = $self->{$key};
    # Always store a fresh array.  Storing the caller's own array ref would
    # let the '-classpath' push below modify the caller's data, and repeated
    # constructions with the same array would accumulate '-classpath' flags.
    $self->{$key} = !defined $value                  ? []
                  : UNIVERSAL::isa($value, 'ARRAY')  ? [ @$value ]
                  :                                    [ $value ];
  }

  if (defined $self->{weka_path}) {
    push @{ $self->{java_args} }, '-classpath', $self->{weka_path};
    delete $self->{weka_path};
  }
  return $self;
}

# java -classpath /Applications/Science/weka-3-2-3/weka.jar weka.classifiers.NaiveBayes -t /tmp/train_file.arff -d /tmp/weka-machine

# Prepares the shared parts of the model before per-category training:
# records the full feature-name list, creates a private working directory
# under tmpdir, and writes a dummy ARFF test file.  It then defers to the
# parent class's create_model (which presumably drives
# create_boolean_model once per category -- confirm in
# AI::Categorizer::Learner::Boolean).
sub create_model {
  my $self  = shift;
  my $model = $self->{model} ||= {};

  $model->{all_features} = [ $self->knowledge_set->features->names ];
  $model->{_in_dir}      = File::Temp::tempdir( DIR => $self->{tmpdir} );

  # WEKA requires a test file (-T) even when we only want to train, so give
  # it a one-row file built from a freshly created feature vector labelled 0.
  my $placeholder = $self->create_delayed_object('features');
  $model->{dummy_file} =
    $self->create_arff_file('dummy', [ [ $placeholder, 0 ] ]);

  return $self->SUPER::create_model(@_);
}

# Trains one binary Weka model for category $cat.
#
# $pos / $neg are array refs of documents inside / outside the category.
# Their feature vectors are written to a temporary ARFF training file with
# labels 1 / 0.  Weka is then run to train on that file and save the model
# into the model's working directory ($self->{model}{_in_dir}).
#
# Returns a hash ref { machine_file => "<category name>_model" }, giving the
# model file's name relative to that working directory.  Dies with the
# error from do_cmd if the Weka run fails; the training file is removed in
# either case.
sub create_boolean_model {
  my ($self, $pos, $neg, $cat) = @_;

  my $cat_name = $cat->name;
  my @docs = (map([$_->features, 1], @$pos),
	      map([$_->features, 0], @$neg));
  my $train_file = $self->create_arff_file($cat_name . '_train', \@docs);

  my %info = (machine_file => $cat_name . '_model');
  my $outfile = File::Spec->catfile($self->{model}{_in_dir}, $info{machine_file});

  # Weka options: -t training file, -T test file (Weka insists on one, hence
  # the dummy file from create_model), -d save the trained model to $outfile.
  # -v and -p 0 keep Weka's evaluation output minimal (see Weka's CLI help).
  my @args = ($self->{java_path},
	      @{$self->{java_args}},
	      $self->{weka_classifier},
	      @{$self->{weka_args}},
	      '-t', $train_file,
	      '-T', $self->{model}{dummy_file},
	      '-d', $outfile,
	      '-v',
	      '-p', '0',
	     );

  # Remove the training file even when the Weka run throws, then re-throw,
  # so a failed run does not leave temporary files behind.
  my $ok  = eval { $self->do_cmd(@args); 1 };
  my $err = $@;
  unlink $train_file or warn "Couldn't remove $train_file: $!";
  die $err unless $ok;

  return \%info;
}

# java -classpath /Applications/Science/weka-3-2-3/weka.jar weka.classifiers.NaiveBayes -l out -T test.arff -p 0

# Scores one document against one category's trained Weka model.
#
# The document's features are written to a temporary ARFF file (labelled 0,
# because the true class is unknown).  Weka is run with -l (load the model
# file named by $info->{machine_file}) and -T (classify that ARFF file), and
# its prediction lines are parsed.  The ARFF file is removed afterwards,
# even when the Weka run throws; any error from do_cmd is re-thrown.
#
# Returns the score Weka printed when it predicted class "1", otherwise 0.
# What Weka's score represents is not established here.
sub get_boolean_score {
  my ($self, $doc, $info) = @_;

  # Create document file
  my $doc_file = $self->create_arff_file('doc', [[$doc->features, 0]], $self->{tmpdir});
  my $machine_file = File::Spec->catfile($self->{model}{_in_dir}, $info->{machine_file});

  my @args = ($self->{java_path},
	      @{$self->{java_args}},
	      $self->{weka_classifier},
	      '-l', $machine_file,
	      '-T', $doc_file,
	      '-p', 0,
	     );

  # The ARFF file is needed only for this run.  Delete it so per-document
  # files do not pile up in tmpdir, including when the run fails.
  my @output;
  my $ok  = eval { @output = $self->do_cmd(@args); 1 };
  my $err = $@;
  unlink $doc_file or warn "Couldn't remove $doc_file: $!";
  die $err unless $ok;

  my %scores;
  foreach my $line (@output) {
    # <doc> <category> <score> <real_category>
    # 0 large.elem 0.4515551620220952 numberth.high
    # The score is a Java double, which Java prints in exponent form for
    # very small or very large values (e.g. 1.25E-4), so accept an optional
    # exponent rather than truncating at the 'E'.
    next unless my ($predicted, $score) =
      $line =~ /^[\d.]+\s+(\S+)\s+([\d.]+(?:[eE][-+]?\d+)?)/;
    $scores{$predicted} = $score;
  }

  return $scores{1} || 0;  # Not sure what weka's scores represent...
}

sub categorize_collection {
  my ($self, %args) = @_;
  my $c = $args{collection} or die "No collection provided";
  
  my @alldocs;
  while (my $d = $c->next) {
    push @alldocs, $d;
  }
  my $doc_file = $self->create_arff_file("docs", [map [$_->features, 0], @alldocs]);

  my @assigned;
  
  my $l = $self->{model}{learners};
  foreach my $cat (keys %$l) {
    my $machine_file = File::Spec->catfile($self->{model}{_in_dir}, "${cat}_model");
    my @args = ($self->{java_path},
		@{$self->{java_args}},
		$self->{weka_classifier},
		'-l', $machine_file,
		'-T', $doc_file,
		'-p', 0,
               );

    my @output = $self->do_cmd(@args);

    foreach my $line (@output) {
      next unless $line =~ /\S/;
      
      # 0 large.elem 0.4515551620220952 numberth.high
      unless ( $line =~ /^([\d.]+)\s+(\S+)\s+([\d.]+)\s+(\S+)/ ) {
	warn "Can't parse line $line";



( run in 1.956 second using v1.01-cache-2.11-cpan-39bf76dae61 )