Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2018
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.92 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import numpy as np
  5. import multiprocessing as mp
  6. from pylsl import StreamInlet, resolve_stream
  7. import scipy.signal as sig
  8. import random
  9. import sys
  10. import pygame as pg
  11. from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
  12. from sklearn.model_selection import train_test_split
  13. from preparing_p300 import preparing_p300 as pre_p300
  14. import matplotlib.pyplot as plt
  15. import pandas as pd
  16.  
  17. import sys
  18. import random
  19.  
  20. file_1 = 'p300_00_2017-11-28-Nov-1511861373_S1.csv' #wybrana liczba:2
  21. file_2 = 'p300_00_2017-11-28-Nov-1511861422_S2.csv' #wybrana liczba:2
  22. file_3 = 'p300_00_2017-11-28-Nov-1511860498.csv' #wybrana liczba: 1
  23. file_4 = 'p300_00_2017-11-28-Nov-1511860588.csv' #wybrana liczba: 1
  24. file_5 = 'p300_00_2017-11-28-Nov-1511862325.csv' #wybrana liczba: 3
  25. file_6 = 'p300_00_2017-11-28-Nov-1511862270.csv' #wybrana liczba: 3
  26.  
  27. mf_1 = []
  28.  
  29.  
  30. sys.path.insert(0, "/home/oskar/hack/LSL/LSL bindings")
  31.  
  32. def amp(sygnal):
  33. mins = np.abs(np.min(sygnal))
  34. maxs = np.max(sygnal)
  35. if ( mins <= maxs ):
  36. amp = maxs
  37. elif ( maxs <= mins ):
  38. amp = np.min(sygnal)
  39. return amp
  40.  
  41. def aamp(sygnal):
  42. aamp = np.abs(amp(sygnal))
  43. return aamp
  44.  
  45. #positive signal area
  46.  
  47. def par(sygnal):
  48. positive = [n for n in sygnal if n >= 0]
  49. par = np.sum(positive)
  50. return par
  51.  
  52. #negative signal area
  53.  
  54. def nar(sygnal):
  55. negative = [n for n in sygnal if n <= 0]
  56. nar = np.sum(negative)
  57. return nar
  58.  
  59. #total signal area
  60. def tar(sygnal):
  61. tar = par(sygnal) + nar(sygnal)
  62. return tar
  63.  
  64.  
  65.  
  66.  
  67. def extraction_morph(signal,nr=1):
  68. feat_array = np.array([par(signal),amp(signal), aamp(signal), nar(signal),tar(signal)])
  69. return feat_array
  70.  
  71.  
  72.  
  73. def filtering(self):
  74.  
  75. """This function filters the data"""
  76.  
  77. bplowcut = 0.5/(self.freq*0.5) #banpass
  78. bphighcut = 15/(self.freq*0.5) #bandpass
  79.  
  80. bslowcut = 49/(self.freq*0.5) #bandstop
  81. bshighcut = 51/(self.freq*0.5) #bandstop
  82.  
  83. [pb,pa] = sig.butter(N=4,Wn=[bslowcut,bshighcut],btype='bandstop')
  84. [sb,sa] = sig.butter(N=4,Wn=[bplowcut,bphighcut],btype='bandpass')
  85.  
  86. data = self.data
  87. result = pd.DataFrame()
  88.  
  89. for n in range(1,9,1):
  90.  
  91. filtered = sig.filtfilt(sb,sa,data['e%s'%str(n)])
  92. filtered = sig.filtfilt(pb,pa,filtered)
  93. result['e%s'%str(n)] = filtered
  94.  
  95. self.data = result
  96.  
  97. def mean_signal(data,freq=250,number_of_electrodes=4):
  98. e_sum = np.zeros((freq,))
  99. electrode_sum = np.zeros((freq,))
  100. il = 0
  101. for z in data.columns:
  102. for n in range(len(data[z])):
  103. if len(data[z][n:])>=freq:
  104. il += 1
  105. e_sum += np.array(data[z][n:n+freq],dtype=float)
  106. e_sum = e_sum/il
  107. electrode_sum += e_sum
  108. e_mean = electrode_sum/number_of_electrodes
  109. return e_mean
  110.  
  111.  
  112. for z in [file_1,file_2,file_3,file_4,file_5,file_6]:
  113. p300 = pre_p300(z,freq=200)
  114. p300.preparing_data()
  115. p300.filtering()
  116. p300.mean_signal()
  117. data = p300.signal_trig
  118. data = data[round(200*0.25):round(200*0.5)]
  119. for n in range(0,4):
  120. mf_1.append(extraction_morph(data.iloc[:,n]))
  121.  
  122. target = np.array([1,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,1])
  123. clf = LinearDiscriminantAnalysis()
  124. clf.fit(mf_1,target)
  125. #print(clf.score(X_test,y_test))
  126.  
  127. class CcaLive(object):
  128.  
  129. def __init__(self, sampling_rate=250, connect=True):
  130.  
  131. # Device parameters
  132.  
  133. self.sampling_rate = sampling_rate
  134. self.connect = connect
  135.  
  136.  
  137. self.__fs = 1./sampling_rate
  138. self.__t = np.arange(0.0, 1.0, self.__fs)
  139.  
  140.  
  141. self.streaming = mp.Event()
  142. self.terminate = mp.Event()
  143.  
  144. def initialize(self, state):
  145. self.prcs = mp.Process(target=self.split, args=(state,))
  146. self.prcs.daemon = True
  147. self.prcs.start()
  148.  
  149. def run_app(self):
  150.  
  151. if self.terminate.is_set():
  152. self.prcs.terminate()
  153. self.terminate.clear()
  154.  
  155. def split(self, state):
  156. self.increment = 0
  157. self.trigger = 0
  158. self.__pre_buffer = []
  159.  
  160. def handle_sample():
  161. ''' Save samples into table; single process '''
  162. while True:
  163. sample = inlet.pull_sample()[0]
  164.  
  165. self.increment += 1
  166.  
  167. self.trigger = state.value
  168. packet = sample[1:5]
  169.  
  170.  
  171. self.correlation.acquire_data(packet, self.trigger)
  172.  
  173. # Set termination
  174. if self.terminate.is_set():
  175. return False
  176.  
  177. # Board connection #
  178. self.correlation = CrossCorrelation(self.sampling_rate, 4)
  179. print("looking for an EEG stream...")
  180. streams = resolve_stream('type', 'EEG')
  181. print("Done!")
  182. inlet = StreamInlet(streams[0])
  183. handle_sample()
  184.  
  185. class CrossCorrelation(object):
  186. """CCA class; returns correlation value for each channel """
  187. def __init__(self, sampling_rate, channels_num):
  188. # self.packet_id = 0
  189. # self.all_packet = 0
  190. # self.stim_num = 0
  191. # self.trigger_num = 0
  192. # self.past_trigger = 0
  193. # self.sampling_rate = sampling_rate
  194. #self.signal_window = np.zeros(shape=(sampling_rate*4, channels_num))
  195. self.window_500 = [np.zeros(5)]
  196. # self.trigger_500 = []
  197. # print(self.window_500)
  198. # self.chunk = np.zeros(((4,250,4)))
  199. #self.channels = np.zeros(shape=(len(self.rs), 3), dtype=tuple)
  200. #self.p300_display = np.zeros(shape=(len(self.rs), 1), dtype=int)
  201.  
  202. # Check if table not empty #
  203.  
  204. def acquire_data(self, packet, trigger):
  205. # print(packet)
  206. # print(trigger)
  207. packet = np.append(packet, trigger)
  208. # packet = [np.zeros(5)]
  209. #print(str(packet))
  210. self.window_500.append(packet)
  211. # print(self.window_500)
  212. if len(self.window_500) >= 2500:
  213. # out_path = "data.csv"
  214. # with open(out_path, 'a') as out_file:
  215. # for i in range(len(self.window_500)):
  216. # # print(str(self.window_500))
  217. # out_file.write(str(self.window_500[i])+'\n')
  218. # # out_file.write('\n')
  219. #
  220. #
  221. # # out = ','.join(str(x) for y in x)
  222. # # out = ','.join(str(self.window_500[i]))
  223. ms = mean_signal(pd.DataFrame(self.window_500))
  224. # print(head(ms))
  225. classified = clf.predict(extraction_morph(ms).reshape(-1,1))
  226. print(classified)
  227. self.window_500 = []
  228.  
  229.  
  230.  
  231. # for n in range(250):
  232. # if self.trigger_500[n+1] != 0 and self.trigger_500[n] != 0:
  233. # self.chunk[(self.trigger_500[n+1])-1] = self.window_500[n:n+250]
  234. # for x in self.window_500[n:n+250]:
  235. # out = ','.join(str(y) for y in x)
  236. # out_path = "chunk"+str((self.trigger_500[n+1])-1)+".csv"
  237. # self.window_500 = self.window_500[250:500]
  238. # self.trigger_500 = self.trigger_500[250:500]
  239.  
  240. def filtering(self, packet):
  241. """ Push single sample into the list """
  242.  
  243. # Butter bandstop filter 49-51hz
  244. for i in range(8):
  245. signal = packet[:, i]
  246. lowcut = 49/(self.sampling_rate*0.5)
  247. highcut = 51/(self.sampling_rate*0.5)
  248. [b, a] = sig.butter(4, [lowcut, highcut], 'bandstop')
  249. packet[:, i] = sig.filtfilt(b, a, signal)
  250.  
  251. # Butter bandpass filter 3-49hz
  252. for i in range(8):
  253. signal = packet[:, i]
  254. lowcut = 3/(self.sampling_rate*0.5)
  255. highcut = 15/(self.sampling_rate*0.5)
  256. [b, a] = sig.butter(4, [lowcut, highcut], 'bandpass')
  257. packet[:, i] = sig.filtfilt(b, a, signal)
  258.  
  259. return packet
  260.  
  261. if __name__ == "__main__":
  262. test = CcaLive()
  263.  
  264. pg.init()
  265. screen = pg.display.set_mode((600, 600))
  266. i1 = pg.image.load('i1.png')
  267. i2 = pg.image.load('i2.png')
  268. i3 = pg.image.load('i3.png')
  269. i4 = pg.image.load('i4.png')
  270. i1p = pg.image.load('i1p.png')
  271. i2p = pg.image.load('i2p.png')
  272. i3p = pg.image.load('i3p.png')
  273. i4p = pg.image.load('i4p.png')
  274.  
  275. clock = pg.time.Clock()
  276. stim = random.randint(1, 4)
  277. state = mp.Value("i", stim)
  278. test.initialize(state)
  279. gap = random.randint(1, 100) + 100
  280. df = 0.0
  281. unlighted = True
  282. ##########################
  283.  
  284. while True:
  285. for event in pg.event.get():
  286. if event.type == pg.QUIT:
  287. pg.display.quit()
  288. pg.quit()
  289. sys.exit()
  290. elif event.type == pg.KEYDOWN:
  291. if event.key == pg.K_ESCAPE:
  292. pg.display.quit()
  293. pg.quit()
  294. sys.exit()
  295. df += clock.tick()
  296.  
  297. while df > gap and unlighted:
  298. state.value = 0
  299. screen.blit(i1,(0,0))
  300. screen.blit(i2,(0,300))
  301. screen.blit(i3,(300,0))
  302. screen.blit(i4,(300,300))
  303. pg.display.flip()
  304. unlighted = False
  305. gap = random.randint(1,100) + 100
  306.  
  307. while df > (100 + gap):
  308. df -= (100 + gap)
  309. unlighted = True
  310. new_stim = random.randint(1, 4)
  311. while stim == new_stim:
  312. new_stim = random.randint(1, 4)
  313. stim = new_stim
  314. state.value = stim
  315.  
  316. if stim == 1:
  317. screen.blit(i1p,(0,0))
  318. elif stim == 2:
  319. screen.blit(i2p,(0,300))
  320. elif stim == 3:
  321. screen.blit(i3p,(300,0))
  322. elif stim == 4:
  323. screen.blit(i4p,(300,300))
  324. pg.display.flip()
  325.  
  326. import gc
  327. collected = gc.collect()
  328. print(collected)
  329.  
  330. # Make sure it's dead.
  331. # if test.prcs.is_alive():
  332. # print("It was alive!")
  333. # test.prcs.terminate()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement