Advertisement
Guest User

Untitled

a guest
May 24th, 2019
292
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.00 KB | None | 0 0
  1. from pylsl import StreamInlet, resolve_stream, resolve_byprop
  2. import numpy as np
  3. import keras
  4. import utils
  5.  
  6. class Classifier_LSTM:
  7.  
  8. def __init__(self, output_directory, input_shape, nb_classes, verbose=False):
  9. self.output_directory = output_directory
  10. self.model = self.build_model(input_shape, nb_classes)
  11. if (verbose == True):
  12. self.model.summary()
  13. self.verbose = verbose
  14. self.model.save_weights(self.output_directory + 'model_init.hdf5')
  15.  
  16. def build_model(self, input_shape, nb_classes):
  17. padding = 'valid'
  18.  
  19. input_layer = keras.layers.Input(input_shape)
  20.  
  21. lstm = keras.layers.LSTM(units=100, return_sequences = True)(input_layer)
  22. time1 = keras.layers.TimeDistributed(keras.layers.Dense(50))(lstm)
  23. time1 = keras.layers.GlobalAveragePooling1D()(time1)
  24.  
  25. output_layer = keras.layers.Dense(units=1, activation='sigmoid')(time1)
  26.  
  27. #conv1 = keras.layers.Conv1D(filters=6,kernel_size=7,padding=padding,activation='sigmoid')(input_layer)
  28. # conv1 = keras.layers.Dropout(rate=0.5)(conv1)
  29. #conv1 = keras.layers.AveragePooling1D(pool_size=3)(conv1)
  30.  
  31. #conv2 = keras.layers.Conv1D(filters=12,kernel_size=7,padding=padding,activation='sigmoid')(conv1)
  32. #conv2 = keras.layers.Dropout(rate=0.5)(conv2)
  33. #conv2 = keras.layers.AveragePooling1D(pool_size=3)(conv2)
  34.  
  35. # flatten_layer = keras.layers.Flatten()(conv2)
  36.  
  37. # output_layer = keras.layers.Dense(units=nb_classes,activation='softmax')(flatten_layer)
  38.  
  39. model = keras.models.Model(inputs=input_layer, outputs=output_layer)
  40.  
  41. model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.01),
  42. metrics=['accuracy'])
  43.  
  44. # file_path = self.output_directory + 'best_model.hdf5'
  45.  
  46. # model_checkpoint = keras.callbacks.ModelCheckpoint(filepath=file_path, monitor='val_geometricmean',
  47. # save_best_only=True,mode='max')
  48.  
  49.  
  50. # self.callbacks = [model_checkpoint, early_stop]
  51.  
  52. return model
  53.  
  54. import time
  55.  
  56. # first resolve an EEG stream on the lab network
  57. print("looking for a stream...")
  58. stream_eeg = resolve_byprop('type', 'EEG')
  59. stream_ecg = resolve_byprop('type', 'ECG')
  60. # create a new inlet to read from the stream
  61. inlet_eeg = StreamInlet(stream_eeg[0], max_chunklen=12)
  62. inlet_ecg = StreamInlet(stream_ecg[0], max_chunklen=12)
  63. epoch_eeg=[]
  64.  
  65. import gc
  66.  
  67. #gc.collect()
  68.  
  69. import tensorflow as tf
  70. import time
  71.  
  72. tf.keras.backend.clear_session()
  73. tf.keras.backend.set_learning_phase(0)
  74.  
  75. time_modelload= time.clock()
  76.  
  77. classifier_eeg = keras.models.load_model('best_model_eeg.hdf5')
  78. classifier_eeg.compile(loss='binary_crossentropy',optimizer=keras.optimizers.Adam(lr=0.01), metrics=['accuracy'])
  79.  
  80. classifier_ecg = keras.models.load_model('best_model_ecg.hdf5')
  81. classifier_ecg.compile(loss='binary_crossentropy',optimizer=keras.optimizers.Adam(lr=0.01), metrics=['accuracy'])
  82.  
  83. #gc.collect()
  84. #time.sleep(10)
  85.  
  86. eeg_inarow_counts=0
  87.  
  88. classifs_time=[]
  89.  
  90. epoch_length=3
  91. info = inlet_eeg.info()
  92. description = info.desc()
  93.  
  94.  
  95. buffer_len=5
  96.  
  97. print("desc:", description)
  98. fs= int(750/3)
  99.  
  100. eeg_buffer = np.zeros((int(fs * 3),4))
  101. ecg_buffer = np.zeros((int(fs * 3),1))
  102. n_win_test = int(np.floor((buffer_len-3 / epoch_length +1)))
  103.  
  104.  
  105. while True:
  106. # get a new sample (you can also omit the timestamp part if you're not
  107. # interested in it)
  108. sample_eeg, timestamp_eeg = inlet_eeg.pull_chunk(timeout=1)
  109. sample_ecg, timestamp_ecg = inlet_ecg.pull_chunk(timeout=1)
  110. # print("sample eeg:", timestamp_eeg)
  111. # print("sample ecg:", timestamp_ecg)
  112. sample_eeg=np.array(sample_eeg)[:,0:4]
  113. sample_ecg = np.array(sample_ecg)[:,0]
  114. # print("ECG timestamp:", timestamp_ecg, "EEG timestamp:", timestamp_eeg)
  115. # print()
  116. # print("ECG epoch made!")
  117.  
  118. #print(sample_eeg.shape)
  119. eeg_buffer, filter_state = utils.update_buffer(eeg_buffer, sample_eeg, filter_state=None)
  120. epoch_eeg=utils.get_last_data(eeg_buffer, 3*fs)
  121. print("eeg shape", epoch_eeg.shape)
  122.  
  123. epoch_eeg=np.array(epoch_eeg).reshape(1,750,4)
  124. # print("EEG epoch made!")
  125.  
  126. predicted_eeg = classifier_eeg.predict(epoch_eeg)>0.5
  127. print("Ictal according to EEG ?: ", predicted_eeg)
  128.  
  129. if(predicted_eeg):
  130. eeg_inarow_counts=eeg_inarow_counts+1
  131. else:
  132. eeg_inarow_counts=0
  133.  
  134. ecg_buffer, filter_state = utils.update_buffer(ecg_buffer, sample_ecg, filter_state=None)
  135. epoch_ecg=utils.get_last_data(ecg_buffer, 3*fs)
  136. print("ecg shape", epoch_ecg.shape)
  137.  
  138. epoch_ecg=np.array(epoch_ecg).reshape(1,750,1)
  139. # print("EEG epoch made!")
  140.  
  141. predicted_ecg = classifier_ecg.predict(epoch_ecg)>0.5
  142. print("Ictal according to ECG ?: ", predicted_ecg)
  143.  
  144. if(predicted_ecg):
  145. ecg_inarow_counts=ecg_inarow_counts+1
  146. else:
  147. ecg_inarow_counts=0
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement