Advertisement
bversteeg

Keras recurrent convolutional BLSTM neural network

Jul 8th, 2016
635
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.96 KB | None | 0 0
  1. import os
  2. os.environ['THEANO_FLAGS'] = ",".join('%s=%s' % (key, value) for key, value in {
  3.     'device': 'cpu',
  4.     'mode': 'FAST_RUN',
  5.     'floatX': 'float32',
  6.     #'exception_verbosity': 'high',
  7.     'optimizer': 'fast_compile',
  8. }.items())
  9.  
  10. import random
  11. from keras.utils.np_utils import *
  12. from keras.models import *
  13. from keras.layers import *
  14. import numpy as np; np.random.seed(1337)
  15.  
  16.  
  17. ########### CONSTANTS / CONFIG ###########
  18. NUM_EPOCHS = 300
  19. ENCODE_MAPPER = {
  20.     'a': (1, 0, 0, 0),
  21.     'c': (0, 1, 0, 0),
  22.     'g': (0, 0, 1, 0),
  23.     't': (0, 0, 0, 1),
  24.     # 'n': (1, 1, 1, 1),
  25. }
  26. NUM_CLASSES = 2
  27. PRINT_WIDTH = 200
  28. MATCH_SEQ = 'ac'
  29.  
  30. ONE_HOT_DIMENSION = len(list(ENCODE_MAPPER.values())[0])
  31. ALPHABET = list(ENCODE_MAPPER.keys())
  32. ALPHABET_SIZE = len(ALPHABET)
  33. DECODE_MAPPER = {
  34.     one_hot: nucleotide
  35.     for nucleotide, one_hot in ENCODE_MAPPER.items()
  36. }
  37.  
  38.  
  39. class Main(object):
  40.     def __init__(self):
  41.         self.model = None
  42.         self.build_model()
  43.         self.perform_training()
  44.         self.perform_testing()
  45.  
  46.     def build_model(self):
  47.         sequence = Input(shape=(None, ONE_HOT_DIMENSION), dtype='float32')
  48.  
  49.         # convolution = Convolution1D(filter_length=6, nb_filter=10)(sequence)
  50.         # max_pooling = MaxPooling1D(pool_length=2)(convolution)
  51.         # dropout = Dropout(0.2)(max_pooling)
  52.  
  53.         dropout = Dropout(0.2)(sequence)
  54.  
  55.         # bidirectional LSTM
  56.         forward_lstm = LSTM(
  57.             output_dim=50, init='uniform', inner_init='uniform', forget_bias_init='one', return_sequences=True,
  58.             activation='tanh', inner_activation='sigmoid',
  59.         )(dropout)
  60.         backward_lstm = LSTM(
  61.             output_dim=50, init='uniform', inner_init='uniform', forget_bias_init='one', return_sequences=True,
  62.             activation='tanh', inner_activation='sigmoid', go_backwards=True,
  63.         )(dropout)
  64.         blstm = merge([forward_lstm, backward_lstm], mode='concat', concat_axis=-1)
  65.  
  66.         dense = TimeDistributed(Dense(NUM_CLASSES))(blstm)
  67.  
  68.         self.model = Model(input=sequence, output=dense)
  69.         print 'Compiling model...'
  70.         self.model.compile(
  71.             loss='binary_crossentropy',
  72.             optimizer='adam',
  73.             metrics=['accuracy']
  74.         )
  75.  
  76.     def perform_training(self):
  77.         print 'Training...'
  78.         sequence_list, y_list = self.generate_sequences((100, 150), 10 )
  79.         X_list = [
  80.             self.encode_sequence(seq)
  81.             for seq in sequence_list
  82.         ]
  83.         for sample_i, (X, y) in enumerate(zip(X_list, y_list)):
  84.             print '\t X.shape=%s  y.shape=%s' % (X.shape, y.shape)
  85.             self.model.fit(
  86.                 X, y,
  87.                 verbose=0,
  88.                 batch_size=1,
  89.                 nb_epoch=NUM_EPOCHS,
  90.                 #callbacks=[ EarlyStopping(monitor='val_loss', patience=3, verbose=1) ]
  91.             )
  92.  
  93.     def perform_testing(self):
  94.         print 'Testing...'
  95.         sequence_list, y_list = self.generate_sequences((100, 150), 10)
  96.         X_list = [
  97.             self.encode_sequence(seq)
  98.             for seq in sequence_list
  99.         ]
  100.         for sample_i, (X, y) in enumerate(zip(X_list, y_list)):
  101.             print '\t X.shape=%s   y.shape=%s' % (X.shape, y.shape)
  102.             y_predicted = np.round(np.array(
  103.                 self.model.predict(X, batch_size=1)
  104.             ))
  105.             print '\t y_predicted.shape=%s' % str(y.shape)
  106.             print 'Test accuracy:', accuracy(y, y_predicted)
  107.  
  108.             print '\nSample #%d' % sample_i
  109.             self.print_seq_labels(
  110.                 sequence_list[sample_i],
  111.                 "".join(
  112.                     '2' if sum(l) == 2 else '0' if sum(l) == 0 else '1' if l[1] else '-'
  113.                     for l in y[0]
  114.                 ),
  115.                 "".join(
  116.                     '2' if sum(l) == 2 else '0' if sum(l) == 0 else '1' if l[1] else '-'
  117.                     for l in y_predicted[0]
  118.                 )
  119.             )
  120.  
  121.     def generate_sequences(self, length_range, num_sequences):
  122.         print 'Generating data...'
  123.         sequence_list, y_list = [], []
  124.  
  125.         for sequence_i in range(num_sequences):
  126.             length = random.randint(*length_range)
  127.             # Generate a random DNA sequence, while ensuring that the gene occurs in the sequence at least once
  128.             sequence = "".join(
  129.                 random.choice(ALPHABET) for i in range(length)
  130.             )
  131.             sequence_list.append(sequence)
  132.             y = np.zeros((len(sequence), 2), dtype=int)
  133.  
  134.             # Find all the positions that the gene matches on the sequence and label these regions
  135.             # matches = [match.start() for match in re.finditer(self.target_gene, sequence)]
  136.             matches = self.find_all_substrings(MATCH_SEQ, sequence)
  137.             print 'Labeled %d genes...' % len(matches)
  138.  
  139.             y[:, 0] = 1
  140.             for match_index in matches:
  141.                 y[:, 0][match_index : match_index + len(MATCH_SEQ)] = 0
  142.                 y[:, 1][match_index : match_index + len(MATCH_SEQ)] = 1
  143.             y_list.append(np.array([y]))
  144.         return sequence_list, y_list
  145.  
  146.     @staticmethod
  147.     def find_all_substrings(sub, string):
  148.         index = 0
  149.         matches = []
  150.         try:
  151.             while True:
  152.                 index = string.index(sub, index + 1)
  153.                 matches.append(index)
  154.         except ValueError:
  155.             pass
  156.         return matches
  157.  
  158.     @staticmethod
  159.     def print_seq_labels(*row_list):
  160.         assert len(set(len(row) for row in row_list)) == 1, 'not all rows are of equal length'
  161.         length = len(row_list[0])
  162.         for i in range(0, length, PRINT_WIDTH):
  163.             for row in row_list:
  164.                 print "".join(map(str, row[i:i + PRINT_WIDTH]))
  165.             print
  166.  
  167.     @staticmethod
  168.     def encode_sequence(sequence):
  169.         return np.array([[
  170.             ENCODE_MAPPER[nucleotide]
  171.             for nucleotide in sequence
  172.         ]])
  173.  
  174.  
# Script entry point: constructing Main builds the model, trains it and then
# tests it (all three steps are triggered from Main.__init__).
if __name__ == '__main__':
    Main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement