AI-TensorFlow-Libtensorflow

 view release on metacpan or  search on metacpan

lib/AI/TensorFlow/Libtensorflow/Manual/Notebook/InferenceUsingTFHubEnformerGeneExprPredModel.pod  view on Meta::CPAN

    outputs_mouse  =>
        AI::TensorFlow::Libtensorflow::Output->New({
            oper => $graph->OperationByName('StatefulPartitionedCall'),
            index => 1,
    }),
);

# Dump the %puts hash for inspection (`p` is presumably Data::Printer's
# pretty-printer, imported outside this excerpt — TODO confirm).
p %puts;

# Run the session on a single input tensor and hand back the first element
# of the output array passed to Run().
#
# Arguments (positional): the session object and the input tensor.
# Uses %puts, $s and AssertOK from the enclosing file.
my $predict_on_batch = sub {
    my ($session, $input_tensor) = @_;

    my $feed_ports   = [ $puts{inputs_args_0} ];
    my $feed_values  = [ $input_tensor ];
    my $fetch_ports  = [ $puts{outputs_human} ];
    my @fetched;

    $session->Run(
        undef,
        $feed_ports,  $feed_values,
        $fetch_ports, \@fetched,
        undef,
        undef,
        $s
    );

    # Abort if the status object reports a failure.
    AssertOK($s);

    return $fetched[0];
};

undef;

use PDL;

# Package-global flag: while true, one_hot_dna() prints its encoder table.
our $SHOW_ENCODER = 1;

# One-hot encode a DNA sequence string.
#
# Parameters:
#   $seq - string made up of the characters A, C, G, T and N (either case).
#
# Returns a float PDL ndarray built by looking every base up in the encoder
# table: A, C, G and T each select a column containing a single 1, N selects
# the all-zero column. For the 10-base test sequence in this file the result
# has dims [4,10].
#
# Dies if $seq contains a character outside the alphabet, or if the
# transliteration operator cannot be compiled.
sub one_hot_dna {
    my ($seq) = @_;

    # bytes::length() is only callable once bytes.pm has been loaded; load it
    # here (without enabling the pragma) instead of relying on another module.
    require bytes;

    my $from_alphabet = "NACGT";
    my $to_alphabet   = pack "C*", 0..length($from_alphabet)-1;

    # sequences from UCSC genome have both uppercase and lowercase bases
    my $from_alphabet_tr = $from_alphabet . lc $from_alphabet;
    my $to_alphabet_tr   = $to_alphabet x 2;

    # A character outside the alphabet would pass through tr/// unchanged and
    # its raw byte value would then be used as an index into the 5-column
    # encoder, so reject it up front with a clear message.
    if( $seq =~ /([^\Q$from_alphabet_tr\E])/ ) {
        die sprintf "Invalid base '%s' in sequence: expected only [%s]",
            $1, $from_alphabet_tr;
    }

    my $p = zeros(byte, bytes::length($seq));
    my $p_dataref = $p->get_dataref;
    ${ $p_dataref } = $seq;

    # tr/// does not interpolate variables, so the operator is built with a
    # string eval. The trailing "1" makes a compile failure detectable even
    # when tr/// itself returns 0.
    for ( ${ $p_dataref } ) {
        eval "tr/$from_alphabet_tr/$to_alphabet_tr/; 1"
            or die "Could not transliterate sequence: $@";
    }
    $p->upd_data;

    my $encoder = append(float(0), identity(float(length($from_alphabet)-1)) );
    say "Encoder is\n", $encoder->info, $encoder if $SHOW_ENCODER;

    my $encoded  = $encoder->index( $p->dummy(0) );

    return $encoded;
}

####

{

# Smoke test: encode a short sequence holding every base in both cases.
say "Testing one-hot encoding:\n";

my $onehot_test_seq = "ACGTNtgcan";
my $test_encoded = one_hot_dna( $onehot_test_seq );
# Switch the encoder dump off so later one_hot_dna() calls stay quiet.
$SHOW_ENCODER = 0;

say "One-hot encoding of sequence '$onehot_test_seq' is:";
say $test_encoded->info, $test_encoded;

}

# Interval: a Bio::Location::Simple subclass that can report its centre,
# return a resized copy around that centre, and stringify itself as
# "seq_id:FTstring".
package Interval {
    use Bio::Location::Simple ();

    use parent qw(Bio::Location::Simple);

    # Midpoint of start and end; one is added when start + end is odd.
    sub center {
        my ($self) = @_;
        my $sum = $self->start + $self->end;
        return int( $sum / 2 ) + $sum % 2;
    }

    # Return a clone of this interval that is $width positions long.
    # With span = $width - 1, the clone starts int(span/2) + span%2 positions
    # before the centre and ends int(span/2) positions after it.
    sub resize {
        my ($self, $width) = @_;
        my $resized = $self->clone;

        my $middle = $self->center;
        my $span   = $width - 1;
        my $after  = int( $span / 2 );
        my $before = $after + $span % 2;

        $resized->start( $middle - $before );
        $resized->end(   $middle + $after  );

        return $resized;
    }

    use overload '""' => \&_op_stringify;

    # Stringification handler; falls back to "(no sequence)" when seq_id is
    # undefined.
    sub _op_stringify {
        my ($self) = @_;
        return sprintf "%s:%s", $self->seq_id // "(no sequence)", $self->to_FTstring;
    }
}

#####

{

say "Testing interval resizing:\n";
# Resize $interval to $to positions, die if the result has any other length,
# otherwise print a one-line report that ends with $msg.
sub _debug_resize {
    my ($interval, $to, $msg) = @_;

    my $resized_interval = $interval->resize($to);
    my $got_length       = $resized_interval->length;

    if( $got_length != $to ) {
        die "Wrong interval size for $interval --($to)--> $resized_interval";
    }

    say sprintf "Interval: %s -> %s, length %2d : %s",
        $interval, $resized_interval, $got_length, $msg;
}

# Exercise resize() on intervals of odd and even length, growing each one by
# 0 to 5 positions.
for my $interval_spec ( [4, 8], [5, 8], [5, 9], [6, 9]) {
    my ($start, $end) = @$interval_spec;
    my $test_interval = Interval->new( -seq_id => 'chr11', -start => $start, -end => $end );
    say sprintf "Testing interval %s with length %d", $test_interval, $test_interval->length;
    say "-----";
    for(0..5) {
        my $base = $test_interval->length;
        my $to = $base + $_;
        # _debug_resize dies when the resized interval is not $to long.
        _debug_resize $test_interval, $to, "$base -> $to (+ $_)";
    }
    say "";
}

}

undef;

use Bio::DB::HTS::Faidx;

# Bio::DB::HTS::Faidx handle used for every sequence lookup below;
# $hg_bgz_path is set outside this excerpt.
my $hg_db = Bio::DB::HTS::Faidx->new( $hg_bgz_path );

# Fetch the bases covered by $interval from $db, padding with 'N' wherever
# the interval reaches past either end of the sequence.
#
# Parameters:
#   $db       - object providing length($seq_id) and
#               get_sequence2_no_length($seq_id, $start, $end)
#   $interval - 1-based location providing seq_id, start, end and clone
#
# Returns a string with one character per position of $interval.
sub extract_sequence {
    my ($db, $interval) = @_;

    my $chrom_length = $db->length($interval->seq_id);

    # An interval lying completely outside 1..$chrom_length has nothing to
    # fetch. Clamping it would give start > end and a result of the wrong
    # length, so answer with pure padding instead.
    if( $interval->start > $chrom_length || $interval->end < 1 ) {
        return 'N' x List::Util::max( $interval->end - $interval->start + 1, 0 );
    }

    # Clamp the requested range to the part that exists in the sequence.
    my $trimmed_interval = $interval->clone;
    $trimmed_interval->start( List::Util::max( $interval->start, 1               ) );
    $trimmed_interval->end(   List::Util::min( $interval->end  , $chrom_length   ) );

    # Bio::DB::HTS::Faidx is 0-based for both start and end points
    my $seq = $db->get_sequence2_no_length(
        $trimmed_interval->seq_id,
        $trimmed_interval->start - 1,
        $trimmed_interval->end   - 1,
    );

    # Number of requested positions before position 1 / after the last base.
    my $pad_upstream   = 'N' x List::Util::max( -($interval->start-1), 0 );
    my $pad_downstream = 'N' x List::Util::max( $interval->end - $chrom_length, 0 );

    return join '', $pad_upstream, $seq, $pad_downstream;
}

# Summarise a sequence for display.
#
# Returns the upper-cased sequence followed by its length. A sequence longer
# than $n characters (default 10) is shown as its first and last $n
# characters joined by "...".
sub seq_info {
    my ($seq, $n) = @_;
    $n ||= 10;

    my $total = length $seq;

    return sprintf( "%s (length %d)", uc($seq), $total )
        if $total <= $n;

    my $head = uc substr( $seq, 0, $n );
    my $tail = uc substr( $seq, -$n );
    return sprintf( "%s...%s (length %d)", $head, $tail, $total );
}

####

{

say "Testing sequence extraction:";

# A single position inside the chromosome.
say "1 base: ",   seq_info
    extract_sequence( $hg_db,
        Interval->new( -seq_id => 'chr11',
            -start => 35_082_742 + 1,
            -end   => 35_082_742 + 1 ) );

# Position 1 widened to 3 positions reaches before the first base, which
# exercises the upstream padding.
say "3 bases: ",  seq_info
    extract_sequence( $hg_db,
        Interval->new( -seq_id => 'chr11',
            -start => 1,
            -end   => 1 )->resize(3) );

# The last position widened to 5 positions reaches past the end, which
# exercises the downstream padding.
say "5 bases: ", seq_info
    extract_sequence( $hg_db,
        Interval->new( -seq_id => 'chr11',
            -start => $hg_db->length('chr11'),
            -end   => $hg_db->length('chr11') )->resize(5) );

# The whole chromosome, resized to its own length.
say "chr11 is of length ", $hg_db->length('chr11');
say "chr11 bases: ", seq_info
    extract_sequence( $hg_db,
        Interval->new( -seq_id => 'chr11',
            -start => 1,
            -end   => $hg_db->length('chr11') )->resize( $hg_db->length('chr11') ) );
}

# Interval whose length must equal $model_central_base_pairs_length
# (defined outside this excerpt).
my $target_interval = Interval->new( -seq_id => 'chr11',
    -start => 35_082_742 +  1, # BioPerl is 1-based
    -end   => 35_197_430 );

say "Target interval: $target_interval with length @{[ $target_interval->length ]}";

die "Target interval is not $model_central_base_pairs_length bp long"
    unless $target_interval->length == $model_central_base_pairs_length;

say "Target sequence is ", seq_info extract_sequence( $hg_db, $target_interval );


say "";


# Widen the target around its centre to $model_sequence_length positions
# (also defined outside this excerpt) and fetch that sequence.
my $resized_interval = $target_interval->resize( $model_sequence_length );
say "Resized interval: $resized_interval with length @{[ $resized_interval->length ]}";

die "resize() is not working properly!" unless $resized_interval->length == $model_sequence_length;

my $seq = extract_sequence( $hg_db, $resized_interval );

say "Resized sequence is ", seq_info($seq);

# One-hot encode the resized sequence and append a trailing dummy dimension.
my $sequence_one_hot = one_hot_dna( $seq )->dummy(-1);

# Print the ndarray's info line; the trailing undef is the cell's result.
say $sequence_one_hot->info; undef;

use Devel::Timer;
# Create a Devel::Timer and record the first mark before the prediction runs.
my $t = Devel::Timer->new;

$t->mark('prediction of sequence');

lib/AI/TensorFlow/Libtensorflow/Manual/Notebook/InferenceUsingTFHubEnformerGeneExprPredModel.pod  view on Meta::CPAN

          [$puts{outputs_human}], \@outputs_t,
          undef,
          undef,
          $s
      );
      AssertOK($s);
  
      return $outputs_t[0];
  };
  
  undef;

=head2 Encoding the data

The model specifies that the way to get a sequence of DNA bases into a C<TFTensor> is to use L<one-hot encoding|https://en.wikipedia.org/wiki/One-hot#Machine_learning_and_statistics> in the order C<ACGT>.

This means that the bases are represented as vectors of length 4:

| base | vector encoding |
|------|-----------------|
| A    | C<[1 0 0 0]>     |
| C    | C<[0 1 0 0]>     |
| G    | C<[0 0 1 0]>     |
| T    | C<[0 0 0 1]>     |
| N    | C<[0 0 0 0]>     |

We can achieve this encoding by creating a lookup table with a PDL ndarray. This could be done by creating a byte PDL ndarray of dimensions C<[ 256 4 ]> to directly look up the numeric value of characters 0-255, but here we'll go with a smaller C...

  use PDL;
  
  # Package-global flag: while true, one_hot_dna() prints its encoder table.
  our $SHOW_ENCODER = 1;
  
  # One-hot encode a DNA sequence string.
  #
  # Parameters:
  #   $seq - string made up of the characters A, C, G, T and N (either case).
  #
  # Returns a float PDL ndarray built by looking every base up in the encoder
  # table: A, C, G and T each select a column containing a single 1, N selects
  # the all-zero column. For the 10-base test sequence below the result has
  # dims [4,10].
  #
  # Dies if $seq contains a character outside the alphabet, or if the
  # transliteration operator cannot be compiled.
  sub one_hot_dna {
      my ($seq) = @_;
  
      # bytes::length() is only callable once bytes.pm has been loaded; load it
      # here (without enabling the pragma) instead of relying on another module.
      require bytes;
  
      my $from_alphabet = "NACGT";
      my $to_alphabet   = pack "C*", 0..length($from_alphabet)-1;
  
      # sequences from UCSC genome have both uppercase and lowercase bases
      my $from_alphabet_tr = $from_alphabet . lc $from_alphabet;
      my $to_alphabet_tr   = $to_alphabet x 2;
  
      # A character outside the alphabet would pass through tr/// unchanged and
      # its raw byte value would then be used as an index into the 5-column
      # encoder, so reject it up front with a clear message.
      if( $seq =~ /([^\Q$from_alphabet_tr\E])/ ) {
          die sprintf "Invalid base '%s' in sequence: expected only [%s]",
              $1, $from_alphabet_tr;
      }
  
      my $p = zeros(byte, bytes::length($seq));
      my $p_dataref = $p->get_dataref;
      ${ $p_dataref } = $seq;
  
      # tr/// does not interpolate variables, so the operator is built with a
      # string eval. The trailing "1" makes a compile failure detectable even
      # when tr/// itself returns 0.
      for ( ${ $p_dataref } ) {
          eval "tr/$from_alphabet_tr/$to_alphabet_tr/; 1"
              or die "Could not transliterate sequence: $@";
      }
      $p->upd_data;
  
      my $encoder = append(float(0), identity(float(length($from_alphabet)-1)) );
      say "Encoder is\n", $encoder->info, $encoder if $SHOW_ENCODER;
  
      my $encoded  = $encoder->index( $p->dummy(0) );
  
      return $encoded;
  }
  
  ####
  
  {
  
  # Smoke test: encode a short sequence holding every base in both cases.
  say "Testing one-hot encoding:\n";
  
  my $onehot_test_seq = "ACGTNtgcan";
  my $test_encoded = one_hot_dna( $onehot_test_seq );
  # Switch the encoder dump off so later one_hot_dna() calls stay quiet.
  $SHOW_ENCODER = 0;
  
  say "One-hot encoding of sequence '$onehot_test_seq' is:";
  say $test_encoded->info, $test_encoded;
  
  }

B<STREAM (STDOUT)>:

  Testing one-hot encoding:
  
  Encoder is
  PDL: Float D [5,4]
  [
   [0 1 0 0 0]
   [0 0 1 0 0]
   [0 0 0 1 0]
   [0 0 0 0 1]
  ]
  
  One-hot encoding of sequence 'ACGTNtgcan' is:
  PDL: Float D [4,10]
  [
   [1 0 0 0]
   [0 1 0 0]
   [0 0 1 0]
   [0 0 0 1]
   [0 0 0 0]
   [0 0 0 1]
   [0 0 1 0]
   [0 1 0 0]
   [1 0 0 0]
   [0 0 0 0]
  ]

B<RESULT>:

  1

Note that in the above, the PDL ndarray's

=over

=item *

first dimension is 4 which matches the last dimension of the input C<TFTensor>;

=item *

second dimension is the sequence length which matches the penultimate dimension of the input C<TFTensor>.

=back

Now we need a way to deal with the sequence interval. We're going to use 1-based coordinates as BioPerl does. In fact, we'll extend a BioPerl class.

  # Interval: a Bio::Location::Simple subclass that can report its centre,
  # return a resized copy around that centre, and stringify itself as
  # "seq_id:FTstring".
  package Interval {
      use Bio::Location::Simple ();
  
      use parent qw(Bio::Location::Simple);
  
      # Midpoint of start and end; one is added when start + end is odd.
      sub center {
          my ($self) = @_;
          my $sum = $self->start + $self->end;
          return int( $sum / 2 ) + $sum % 2;
      }
  
      # Return a clone of this interval that is $width positions long.
      # With span = $width - 1, the clone starts int(span/2) + span%2 positions
      # before the centre and ends int(span/2) positions after it.
      sub resize {
          my ($self, $width) = @_;
          my $resized = $self->clone;
  
          my $middle = $self->center;
          my $span   = $width - 1;
          my $after  = int( $span / 2 );
          my $before = $after + $span % 2;
  
          $resized->start( $middle - $before );
          $resized->end(   $middle + $after  );
  
          return $resized;
      }
  
      use overload '""' => \&_op_stringify;
  
      # Stringification handler; falls back to "(no sequence)" when seq_id is
      # undefined.
      sub _op_stringify {
          my ($self) = @_;
          return sprintf "%s:%s", $self->seq_id // "(no sequence)", $self->to_FTstring;
      }
  }
  
  #####
  
  {
  
  say "Testing interval resizing:\n";
  # Resize $interval to $to positions, die if the result has any other length,
  # otherwise print a one-line report that ends with $msg.
  sub _debug_resize {
      my ($interval, $to, $msg) = @_;
  
      my $resized_interval = $interval->resize($to);
      my $got_length       = $resized_interval->length;
  
      if( $got_length != $to ) {
          die "Wrong interval size for $interval --($to)--> $resized_interval";
      }
  
      say sprintf "Interval: %s -> %s, length %2d : %s",
          $interval, $resized_interval, $got_length, $msg;
  }
  
  # Exercise resize() on intervals of odd and even length, growing each one by
  # 0 to 5 positions.
  for my $interval_spec ( [4, 8], [5, 8], [5, 9], [6, 9]) {
      my ($start, $end) = @$interval_spec;
      my $test_interval = Interval->new( -seq_id => 'chr11', -start => $start, -end => $end );
      say sprintf "Testing interval %s with length %d", $test_interval, $test_interval->length;
      say "-----";
      for(0..5) {
          my $base = $test_interval->length;
          my $to = $base + $_;
          # _debug_resize dies when the resized interval is not $to long.
          _debug_resize $test_interval, $to, "$base -> $to (+ $_)";
      }
      say "";
  }
  
  }
  
  undef;

B<STREAM (STDOUT)>:

  Testing interval resizing:
  
  Testing interval chr11:4..8 with length 5
  -----
  Interval: chr11:4..8 -> chr11:4..8, length  5 : 5 -> 5 (+ 0)
  Interval: chr11:4..8 -> chr11:3..8, length  6 : 5 -> 6 (+ 1)
  Interval: chr11:4..8 -> chr11:3..9, length  7 : 5 -> 7 (+ 2)
  Interval: chr11:4..8 -> chr11:2..9, length  8 : 5 -> 8 (+ 3)
  Interval: chr11:4..8 -> chr11:2..10, length  9 : 5 -> 9 (+ 4)
  Interval: chr11:4..8 -> chr11:1..10, length 10 : 5 -> 10 (+ 5)
  
  Testing interval chr11:5..8 with length 4
  -----
  Interval: chr11:5..8 -> chr11:5..8, length  4 : 4 -> 4 (+ 0)
  Interval: chr11:5..8 -> chr11:5..9, length  5 : 4 -> 5 (+ 1)
  Interval: chr11:5..8 -> chr11:4..9, length  6 : 4 -> 6 (+ 2)
  Interval: chr11:5..8 -> chr11:4..10, length  7 : 4 -> 7 (+ 3)
  Interval: chr11:5..8 -> chr11:3..10, length  8 : 4 -> 8 (+ 4)
  Interval: chr11:5..8 -> chr11:3..11, length  9 : 4 -> 9 (+ 5)
  
  Testing interval chr11:5..9 with length 5
  -----
  Interval: chr11:5..9 -> chr11:5..9, length  5 : 5 -> 5 (+ 0)
  Interval: chr11:5..9 -> chr11:4..9, length  6 : 5 -> 6 (+ 1)
  Interval: chr11:5..9 -> chr11:4..10, length  7 : 5 -> 7 (+ 2)
  Interval: chr11:5..9 -> chr11:3..10, length  8 : 5 -> 8 (+ 3)
  Interval: chr11:5..9 -> chr11:3..11, length  9 : 5 -> 9 (+ 4)
  Interval: chr11:5..9 -> chr11:2..11, length 10 : 5 -> 10 (+ 5)
  
  Testing interval chr11:6..9 with length 4
  -----
  Interval: chr11:6..9 -> chr11:6..9, length  4 : 4 -> 4 (+ 0)
  Interval: chr11:6..9 -> chr11:6..10, length  5 : 4 -> 5 (+ 1)
  Interval: chr11:6..9 -> chr11:5..10, length  6 : 4 -> 6 (+ 2)
  Interval: chr11:6..9 -> chr11:5..11, length  7 : 4 -> 7 (+ 3)
  Interval: chr11:6..9 -> chr11:4..11, length  8 : 4 -> 8 (+ 4)
  Interval: chr11:6..9 -> chr11:4..12, length  9 : 4 -> 9 (+ 5)
  


  use Bio::DB::HTS::Faidx;
  
  # Bio::DB::HTS::Faidx handle used for every sequence lookup below;
  # $hg_bgz_path is set outside this excerpt.
  my $hg_db = Bio::DB::HTS::Faidx->new( $hg_bgz_path );
  
  # Fetch the bases covered by $interval from $db, padding with 'N' wherever
  # the interval reaches past either end of the sequence.
  #
  # Parameters:
  #   $db       - object providing length($seq_id) and
  #               get_sequence2_no_length($seq_id, $start, $end)
  #   $interval - 1-based location providing seq_id, start, end and clone
  #
  # Returns a string with one character per position of $interval.
  sub extract_sequence {
      my ($db, $interval) = @_;
  
      my $chrom_length = $db->length($interval->seq_id);
  
      # An interval lying completely outside 1..$chrom_length has nothing to
      # fetch. Clamping it would give start > end and a result of the wrong
      # length, so answer with pure padding instead.
      if( $interval->start > $chrom_length || $interval->end < 1 ) {
          return 'N' x List::Util::max( $interval->end - $interval->start + 1, 0 );
      }
  
      # Clamp the requested range to the part that exists in the sequence.
      my $trimmed_interval = $interval->clone;
      $trimmed_interval->start( List::Util::max( $interval->start, 1               ) );
      $trimmed_interval->end(   List::Util::min( $interval->end  , $chrom_length   ) );
  
      # Bio::DB::HTS::Faidx is 0-based for both start and end points
      my $seq = $db->get_sequence2_no_length(
          $trimmed_interval->seq_id,
          $trimmed_interval->start - 1,
          $trimmed_interval->end   - 1,
      );
  
      # Number of requested positions before position 1 / after the last base.
      my $pad_upstream   = 'N' x List::Util::max( -($interval->start-1), 0 );
      my $pad_downstream = 'N' x List::Util::max( $interval->end - $chrom_length, 0 );
  
      return join '', $pad_upstream, $seq, $pad_downstream;
  }
  
  # Summarise a sequence for display.
  #
  # Returns the upper-cased sequence followed by its length. A sequence longer
  # than $n characters (default 10) is shown as its first and last $n
  # characters joined by "...".
  sub seq_info {
      my ($seq, $n) = @_;
      $n ||= 10;
  
      my $total = length $seq;
  
      return sprintf( "%s (length %d)", uc($seq), $total )
          if $total <= $n;
  
      my $head = uc substr( $seq, 0, $n );
      my $tail = uc substr( $seq, -$n );
      return sprintf( "%s...%s (length %d)", $head, $tail, $total );
  }
  
  ####
  
  {
  
  say "Testing sequence extraction:";
  
  # A single position inside the chromosome.
  say "1 base: ",   seq_info
      extract_sequence( $hg_db,
          Interval->new( -seq_id => 'chr11',
              -start => 35_082_742 + 1,
              -end   => 35_082_742 + 1 ) );
  
  # Position 1 widened to 3 positions reaches before the first base, which
  # exercises the upstream padding.
  say "3 bases: ",  seq_info
      extract_sequence( $hg_db,
          Interval->new( -seq_id => 'chr11',
              -start => 1,
              -end   => 1 )->resize(3) );
  
  # The last position widened to 5 positions reaches past the end, which
  # exercises the downstream padding.
  say "5 bases: ", seq_info
      extract_sequence( $hg_db,
          Interval->new( -seq_id => 'chr11',
              -start => $hg_db->length('chr11'),
              -end   => $hg_db->length('chr11') )->resize(5) );
  
  # The whole chromosome, resized to its own length.
  say "chr11 is of length ", $hg_db->length('chr11');
  say "chr11 bases: ", seq_info
      extract_sequence( $hg_db,
          Interval->new( -seq_id => 'chr11',
              -start => 1,
              -end   => $hg_db->length('chr11') )->resize( $hg_db->length('chr11') ) );
  }

B<STREAM (STDOUT)>:

  Testing sequence extraction:
  1 base: G (length 1)
  3 bases: NNN (length 3)
  5 bases: NNNNN (length 5)
  chr11 is of length 135086622
  chr11 bases: NNNNNNNNNN...NNNNNNNNNN (length 135086622)

B<RESULT>:

  1

Now we can use the same target interval that is used in the example notebook which recreates part of L<figure 1|https://www.nature.com/articles/s41592-021-01252-x/figures/1> from the Enformer paper.

  # Interval whose length must equal $model_central_base_pairs_length
  # (defined outside this excerpt).
  my $target_interval = Interval->new( -seq_id => 'chr11',
      -start => 35_082_742 +  1, # BioPerl is 1-based
      -end   => 35_197_430 );
  
  say "Target interval: $target_interval with length @{[ $target_interval->length ]}";
  
  die "Target interval is not $model_central_base_pairs_length bp long"
      unless $target_interval->length == $model_central_base_pairs_length;
  
  say "Target sequence is ", seq_info extract_sequence( $hg_db, $target_interval );
  
  
  say "";
  
  
  # Widen the target around its centre to $model_sequence_length positions
  # (also defined outside this excerpt) and fetch that sequence.
  my $resized_interval = $target_interval->resize( $model_sequence_length );
  say "Resized interval: $resized_interval with length @{[ $resized_interval->length ]}";
  
  die "resize() is not working properly!" unless $resized_interval->length == $model_sequence_length;
  
  my $seq = extract_sequence( $hg_db, $resized_interval );
  
  say "Resized sequence is ", seq_info($seq);

B<STREAM (STDOUT)>:

  Target interval: chr11:35082743..35197430 with length 114688
  Target sequence is GGTGGCAGCC...ATCTCCTTTT (length 114688)
  
  Resized interval: chr11:34943479..35336694 with length 393216
  Resized sequence is ACTAGTTCTA...GGCCCAAATC (length 393216)

B<RESULT>:

  1

To prepare the input we have to one-hot encode this resized sequence and give it a dummy dimension at the end to indicate that it is a batch with a single sequence. Then we can turn the PDL ndarray into a C<TFTensor> and pass it to our prediction ...

  # One-hot encode the resized sequence and append a trailing dummy dimension.
  my $sequence_one_hot = one_hot_dna( $seq )->dummy(-1);
  
  # Print the ndarray's info line; the trailing undef is the cell's result.
  say $sequence_one_hot->info; undef;

B<STREAM (STDOUT)>:

  PDL: Float D [4,393216,1]


  use Devel::Timer;



( run in 0.682 second using v1.01-cache-2.11-cpan-cdf2f3d4e48 )