Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- import numpy as np
- import multiprocessing as mp
- from pylsl import StreamInlet, resolve_stream
- import scipy.signal as sig
- import random
- import sys
- import pygame as pg
- from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
- from sklearn.model_selection import train_test_split
- from preparing_p300 import preparing_p300 as pre_p300
- import matplotlib.pyplot as plt
- import pandas as pd
- import sys
- import random
# Make the local LSL Python bindings importable.
# NOTE(review): this runs after `from pylsl import ...` at the top of the
# file, so it cannot affect that import; move it above the pylsl import if
# the bindings are really loaded from this directory.
sys.path.insert(0, "/home/oskar/hack/LSL/LSL bindings")

# Recorded P300 training sessions. The trailing comment is the number the
# subject attended to in that session ("wybrana liczba" = "chosen number").
file_1 = 'p300_00_2017-11-28-Nov-1511861373_S1.csv'  # chosen number: 2
file_2 = 'p300_00_2017-11-28-Nov-1511861422_S2.csv'  # chosen number: 2
file_3 = 'p300_00_2017-11-28-Nov-1511860498.csv'  # chosen number: 1
file_4 = 'p300_00_2017-11-28-Nov-1511860588.csv'  # chosen number: 1
file_5 = 'p300_00_2017-11-28-Nov-1511862325.csv'  # chosen number: 3
file_6 = 'p300_00_2017-11-28-Nov-1511862270.csv'  # chosen number: 3

# Feature rows for classifier training; filled by the training loop.
mf_1 = []
def amp(sygnal):
    """Return the extreme sample of ``sygnal`` with the largest magnitude.

    The sign is preserved: for ``[1, -3, 2]`` the result is ``-3``. When the
    maximum equals the magnitude of the minimum, the (positive) maximum wins.
    If the signal contains NaN, NaN is returned instead of failing with an
    unbound local variable.

    Raises:
        ValueError: if ``sygnal`` is empty (raised by ``np.min``).
    """
    lowest = np.min(sygnal)
    highest = np.max(sygnal)
    # A NaN makes the comparison False, so NaN (the minimum) propagates.
    return highest if np.abs(lowest) <= highest else lowest
def aamp(sygnal):
    """Absolute amplitude: the magnitude of the value returned by ``amp``."""
    return np.abs(amp(sygnal))
# Positive signal area (PAR).
def par(sygnal):
    """Sum of every non-negative sample of ``sygnal`` (0.0 if there are none)."""
    non_negative = list(filter(lambda sample: sample >= 0, sygnal))
    return np.sum(non_negative)
# Negative signal area (NAR).
def nar(sygnal):
    """Sum of every non-positive sample of ``sygnal`` (0.0 if there are none)."""
    non_positive = list(filter(lambda sample: sample <= 0, sygnal))
    return np.sum(non_positive)
# Total signal area (TAR).
def tar(sygnal):
    """Positive area plus (non-positive) negative area of ``sygnal``."""
    positive_area = par(sygnal)
    negative_area = nar(sygnal)
    return positive_area + negative_area
def extraction_morph(signal, nr=1):
    """Return the morphological feature vector of one signal.

    The five features are, in order: PAR, AMP, AAMP, NAR, TAR. ``nr`` is
    accepted for backward compatibility and is not used.
    """
    extractors = (par, amp, aamp, nar, tar)
    return np.array([extract(signal) for extract in extractors])
def filtering(self):
    """Band-pass 0.5-15 Hz, then band-stop 49-51 Hz, electrodes e1..e8.

    Reads ``self.freq`` (sampling rate) and ``self.data`` (which must have
    columns ``'e1'`` .. ``'e8'``). Replaces ``self.data`` with a new DataFrame
    that holds only the eight filtered electrode columns. Both filters are
    4th-order Butterworth designs applied forward-backward (``filtfilt``).

    NOTE(review): this is a module-level function that takes ``self``; the
    training loop calls ``p300.filtering()`` on the ``pre_p300`` object
    instead, so confirm whether this copy is still used.
    """
    nyquist = self.freq * 0.5
    band_b, band_a = sig.butter(N=4, Wn=[0.5 / nyquist, 15 / nyquist], btype='bandpass')
    notch_b, notch_a = sig.butter(N=4, Wn=[49 / nyquist, 51 / nyquist], btype='bandstop')
    source = self.data
    cleaned = pd.DataFrame()
    for electrode in range(1, 9):
        column = 'e%d' % electrode
        band_limited = sig.filtfilt(band_b, band_a, source[column])
        cleaned[column] = sig.filtfilt(notch_b, notch_a, band_limited)
    self.data = cleaned
def mean_signal(data, freq=250, number_of_electrodes=4):
    """Average all full-length sliding windows, then average across electrodes.

    For each of the first ``number_of_electrodes`` columns of ``data``, every
    window of ``freq`` consecutive samples (stride 1) is averaged into a
    single window. The per-electrode averages are then averaged together.
    Extra columns (e.g. a trailing trigger column) are ignored.

    Args:
        data: DataFrame whose leading columns are electrode signals.
        freq: window length in samples.
        number_of_electrodes: how many leading columns to average.

    Returns:
        1-D float array of length ``freq``.

    Raises:
        ValueError: if there are no electrode columns or ``data`` has fewer
            than ``freq`` rows.

    NOTE(review): averaging every stride-1 window smears any stimulus-locked
    response; trigger-locked epochs may have been intended — confirm.
    """
    electrodes = list(data.columns[:number_of_electrodes])
    if not electrodes:
        raise ValueError("mean_signal: data has no electrode columns")
    n_windows = len(data) - freq + 1
    if n_windows < 1:
        raise ValueError(
            "mean_signal: need at least %d rows, got %d" % (freq, len(data)))
    total = np.zeros(freq)
    for column in electrodes:
        samples = np.asarray(data[column], dtype=float)
        # Sample k of the mean window averages x[j + k] over all window starts
        # j, i.e. the mean of x[k : k + n_windows].
        total += np.array([samples[k:k + n_windows].mean() for k in range(freq)])
    return total / len(electrodes)
# Train the P300 classifier on the recorded sessions (runs at import time).
# Each file contributes one feature row per column 0..3 of its
# ``signal_trig`` slice: 6 files x 4 columns = 24 rows, one per label below.
# NOTE(review): the label position in ``target`` does not obviously match the
# "chosen number" comments on the file constants (file_1/file_2 say 2 but are
# labelled at position 0); confirm the column order of ``signal_trig``.
# NOTE(review): training uses samples 50..99 at 200 Hz, while live
# classification in CrossCorrelation uses a full mean_signal window; confirm
# the two feature pipelines are meant to match.
session_files = [file_1, file_2, file_3, file_4, file_5, file_6]
epoch_start, epoch_stop = round(200 * 0.25), round(200 * 0.5)
for recording in session_files:
    session = pre_p300(recording, freq=200)
    session.preparing_data()
    session.filtering()
    session.mean_signal()
    window = session.signal_trig[epoch_start:epoch_stop]
    mf_1.extend(extraction_morph(window.iloc[:, column]) for column in range(4))

target = np.array([1, 0, 0, 0,
                   1, 0, 0, 0,
                   0, 1, 0, 0,
                   0, 1, 0, 0,
                   0, 0, 0, 1,
                   0, 0, 0, 1])
clf = LinearDiscriminantAnalysis()
clf.fit(mf_1, target)
class CcaLive(object):
    """Acquire EEG samples from an LSL stream in a background process.

    ``initialize`` starts a daemon ``multiprocessing.Process`` running
    ``split``. That method connects to the first LSL stream of type ``'EEG'``
    and forwards each sample, tagged with the current stimulus code, to a
    ``CrossCorrelation`` buffer.
    """

    def __init__(self, sampling_rate=250, connect=True):
        # Device parameters
        self.sampling_rate = sampling_rate
        self.connect = connect  # stored; not read inside this class
        self.__fs = 1. / sampling_rate  # sample period, 1 / sampling_rate
        self.__t = np.arange(0.0, 1.0, self.__fs)  # not read inside this class
        self.streaming = mp.Event()  # not read inside this class
        # Set to stop the acquisition loop (checked in split / run_app).
        self.terminate = mp.Event()

    def initialize(self, state):
        """Start the acquisition process.

        ``state`` is presumably a shared ``mp.Value`` holding the stimulus
        code currently on screen (the ``__main__`` block passes one).
        """
        self.prcs = mp.Process(target=self.split, args=(state,))
        self.prcs.daemon = True
        self.prcs.start()

    def run_app(self):
        """Stop the acquisition process if termination has been requested.

        Safe to call before ``initialize``. Clears the terminate flag afterwards.
        """
        if not self.terminate.is_set():
            return
        process = getattr(self, 'prcs', None)
        if process is not None:
            process.terminate()
            # Reap the child so it does not linger as a zombie process.
            process.join()
        self.terminate.clear()

    def split(self, state):
        """Acquisition loop; runs in the child process started by ``initialize``.

        Blocks until an LSL 'EEG' stream is found, then pulls samples forever,
        forwarding each to ``self.correlation``, until ``self.terminate`` is
        set.
        """
        self.increment = 0  # number of samples pulled so far
        self.trigger = 0  # stimulus code attached to the latest sample
        self.__pre_buffer = []  # not read inside this class

        def handle_sample():
            ''' Save samples into table; single process '''
            while True:
                sample = inlet.pull_sample()[0]
                self.increment += 1
                self.trigger = state.value
                # NOTE(review): keeps channels 1..4 and drops channel 0 —
                # confirm against the amplifier's channel layout.
                packet = sample[1:5]
                self.correlation.acquire_data(packet, self.trigger)
                # Set termination
                if self.terminate.is_set():
                    return False

        # Board connection #
        self.correlation = CrossCorrelation(self.sampling_rate, 4)
        print("looking for an EEG stream...")
        streams = resolve_stream('type', 'EEG')
        print("Done!")
        inlet = StreamInlet(streams[0])
        handle_sample()
class CrossCorrelation(object):
    """Buffer live EEG samples and classify each full buffer.

    Despite the name, this class computes no correlations. It collects rows
    of ``channels_num`` channel values plus a trigger code. Once
    ``WINDOW_LENGTH`` rows are buffered, it:

    * averages the channel columns with the module-level ``mean_signal``;
    * extracts features with ``extraction_morph``;
    * prints the prediction of the module-level classifier ``clf``;
    * empties the buffer.
    """

    # Number of buffered rows that triggers one classification.
    WINDOW_LENGTH = 2500

    def __init__(self, sampling_rate, channels_num):
        self.sampling_rate = sampling_rate
        self.channels_num = channels_num
        # Rows of [channel_1 .. channel_N, trigger]. Starts empty so the first
        # window contains no synthetic all-zero sample.
        self.window_500 = []

    def acquire_data(self, packet, trigger):
        """Append one sample and classify the buffer once it is full.

        Args:
            packet: channel values for one sample (``channels_num`` of them).
            trigger: stimulus code active for this sample (the live app
                writes 0 while no tile is highlighted).
        """
        self.window_500.append(np.append(packet, trigger))
        if len(self.window_500) < self.WINDOW_LENGTH:
            return
        frame = pd.DataFrame(self.window_500)
        # Average only the EEG channels; the last column holds trigger codes.
        ms = mean_signal(frame.iloc[:, :self.channels_num],
                         freq=self.sampling_rate,
                         number_of_electrodes=self.channels_num)
        # predict() needs a 2-D (n_samples, n_features) array: one sample
        # with the five features clf was trained on, not five 1-feature samples.
        features = extraction_morph(ms).reshape(1, -1)
        classified = clf.predict(features)
        print(classified)
        self.window_500 = []

    def filtering(self, packet):
        """Notch 49-51 Hz, then band-pass 3-15 Hz, each EEG channel of ``packet``.

        ``packet`` is a 2-D array (samples x columns). Its first
        ``channels_num`` columns are filtered in place; any trailing trigger
        column is left untouched. The same array is returned. The number of
        rows must exceed ``scipy.signal.filtfilt``'s default padding length.
        """
        nyquist = self.sampling_rate * 0.5
        # The filters depend only on the sampling rate: design them once.
        notch_b, notch_a = sig.butter(4, [49 / nyquist, 51 / nyquist], 'bandstop')
        band_b, band_a = sig.butter(4, [3 / nyquist, 15 / nyquist], 'bandpass')
        for i in range(min(self.channels_num, packet.shape[1])):
            notched = sig.filtfilt(notch_b, notch_a, packet[:, i])
            packet[:, i] = sig.filtfilt(band_b, band_a, notched)
        return packet
if __name__ == "__main__":
    test = CcaLive()
    pg.init()
    screen = pg.display.set_mode((600, 600))
    i1 = pg.image.load('i1.png')
    i2 = pg.image.load('i2.png')
    i3 = pg.image.load('i3.png')
    i4 = pg.image.load('i4.png')
    i1p = pg.image.load('i1p.png')
    i2p = pg.image.load('i2p.png')
    i3p = pg.image.load('i3p.png')
    i4p = pg.image.load('i4p.png')
    # Stimulus code -> (idle image, highlighted image, top-left corner).
    tiles = {
        1: (i1, i1p, (0, 0)),
        2: (i2, i2p, (0, 300)),
        3: (i3, i3p, (300, 0)),
        4: (i4, i4p, (300, 300)),
    }

    def draw_idle_grid():
        """Blit all four tiles in their idle (not highlighted) state."""
        for idle, _, corner in tiles.values():
            screen.blit(idle, corner)

    def highlight(code):
        """Blit the highlighted image of tile ``code`` over the grid."""
        _, lit, corner = tiles[code]
        screen.blit(lit, corner)

    def shutdown():
        """Stop the acquisition process, close pygame and exit."""
        test.terminate.set()
        test.run_app()
        pg.display.quit()
        pg.quit()
        sys.exit()

    clock = pg.time.Clock()
    stim = random.randint(1, 4)
    state = mp.Value("i", stim)  # shared with the acquisition process
    test.initialize(state)
    # Show the first stimulus so the screen agrees with ``state``.
    draw_idle_grid()
    highlight(stim)
    pg.display.flip()
    gap = random.randint(1, 100) + 100  # 101-200 ms
    df = 0.0  # ms elapsed in the current highlight/idle cycle
    highlight_on = True  # True until the current highlight has been cleared
    ##########################
    while True:
        for event in pg.event.get():
            if event.type == pg.QUIT:
                shutdown()
            elif event.type == pg.KEYDOWN and event.key == pg.K_ESCAPE:
                shutdown()
        df += clock.tick()
        # Highlight phase over: show the idle grid and report "no stimulus".
        if df > gap and highlight_on:
            state.value = 0
            draw_idle_grid()
            pg.display.flip()
            highlight_on = False
            gap = random.randint(1, 100) + 100
        # NOTE(review): ``gap`` was just re-drawn above, so the idle phase lasts
        # roughly 100 + new_gap - old_gap ms (1-199 ms), not a fixed 100 ms.
        # Confirm this jitter is intended.
        while df > (100 + gap):
            df -= (100 + gap)
            highlight_on = True
            new_stim = random.randint(1, 4)
            while stim == new_stim:
                new_stim = random.randint(1, 4)
            stim = new_stim
            state.value = stim
            highlight(stim)
            pg.display.flip()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement