Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import numpy as np
import matplotlib.pyplot as plt
from mne import viz
from numpy.linalg import inv
from toolbox.utils import chan4plot, downsample, cov_shrinkage, calc_confusion, calc_AUC

# Load data
datafile = "data/p3bci_data.npz"
# For an .npz archive np.load returns an NpzFile that keeps the file open and
# reads arrays on access; the `with` block closes it once every array below
# has been read. allow_pickle=True lets object arrays be unpickled, which can
# execute arbitrary code, so only load files from a trusted source.
with np.load(datafile, allow_pickle=True) as D:
    data = D['data']
    onsets = D['onsets']
    timestamps = D['timestamps']
    flashseq = D['flashseq']
    targets = D['targets']
# Epoch extraction: cut a fixed-length window of samples at every stimulus onset.
sample_size = 205  # samples per epoch, counted from the onset sample
trials = 30        # number of trials
subtrials = 10     # subtrials per trial
stimnum = 12       # stimuli (flashes) per subtrial
n_epochs = trials * subtrials * stimnum
# One epoch per stimulus; the trailing axes follow the raw data's layout
# (samples x channels) instead of a hard-coded channel count.
epochs = np.zeros((n_epochs, sample_size) + data.shape[1:])
label = np.zeros((trials, subtrials, stimnum), dtype=int)
c = 0
for i in range(trials):
    for j in range(subtrials):
        for k in range(stimnum):
            # Index of the sample whose timestamp is closest to this onset.
            flash = np.abs(timestamps - onsets[i][j][k]).argmin()
            sample = data[flash:flash + sample_size]
            if len(sample) < sample_size:
                raise ValueError(
                    f"Epoch for trial {i}, subtrial {j}, stimulus {k} starts at "
                    f"sample {flash} and runs past the end of the recording"
                )
            # Epochs are stored in (trial, subtrial, stimulus) order.
            epochs[c] = sample
            c += 1
# Assign labels: a stimulus is a target (1) when its flash code equals one of
# the two codes derived from that trial's target index -- index // 6 or
# index % 6 + 6 (presumably the row and column of a 6x6 speller grid;
# confirm against the paradigm) -- and a non-target (0) otherwise.
for trial in range(trials):
    target_idx = targets[0, trial]
    row_code = target_idx // 6
    col_code = (target_idx % 6) + 6
    for sub in range(subtrials):
        for stim in range(stimnum):
            code = flashseq[trial][sub][stim]
            label[trial][sub][stim] = code == col_code or code == row_code
# Flatten in the same (trial, subtrial, stimulus) order the epochs were cut.
label = label.flatten()
# NOTE: this rebinds `targets` from the per-trial target indices to the
# epochs of target stimuli.
is_target = label == 1
targets = epochs[is_target]
non_targets = epochs[label == 0]
# Optional ERP inspection plots; set PLOT_ERP = True to enable them.
PLOT_ERP = False
ERP_CHANNEL = 1  # channel index used for the averaged time-course plot
ERP_SAMPLE = 90  # sample index used for the topographic map

if PLOT_ERP:
    x = range(sample_size)
    target_means = np.mean(targets, 0)
    non_target_means = np.mean(non_targets, 0)
    # Averaged time course of targets (blue) vs non-targets (red).
    plt.xlabel('Samples')
    plt.ylabel('Amplitude')
    plt.plot(x, target_means[:, ERP_CHANNEL], color='b', label='Targets')
    plt.plot(x, non_target_means[:, ERP_CHANNEL], color='r', label='Non-targets')
    plt.legend()
    plt.show()
    # Topographic map of the squared target/non-target difference at one sample.
    erp = np.square(target_means[ERP_SAMPLE, :] - non_target_means[ERP_SAMPLE, :])
    eeginfo = chan4plot()
    viz.plot_topomap(erp, eeginfo, show=False)
    plt.show()
# Reshape: downsample every epoch, flatten each one into a single feature
# vector and split the epochs into contiguous cross-validation folds.
n_folds = 5
# Assumes toolbox.utils.downsample reduces the time axis by a factor of 10
# (205 samples x 10 channels -> 200 features per epoch); confirm there.
data = downsample(epochs, 10)
n_total = len(epochs)
flattened_data = data.reshape(n_total, -1)
if n_total % n_folds:
    raise ValueError(f"{n_total} epochs cannot be split into {n_folds} equal folds")
# Epochs are stored trial-major, so when `trials` is divisible by n_folds each
# contiguous fold holds whole trials and no trial is split between training
# and test data.
data_chunks = flattened_data.reshape(n_folds, n_total // n_folds, -1)
label_chunks = np.reshape(label, (n_folds, -1))
# Fisher's Linear Discriminant Analysis
def fda_train(data, label):
    """Train a Fisher linear discriminant on labelled feature vectors.

    Parameters
    ----------
    data : array_like, shape (..., n_features)
        Feature vectors. Any leading axes (e.g. (n_folds, n_epochs)) are
        flattened so that every row is one sample.
    label : array_like, shape data.shape[:-1]
        1 marks a target sample, 0 a non-target sample.

    Returns
    -------
    fda_w : ndarray, shape (n_features,)
        Discriminant direction Sw^-1 (mean_target - mean_nontarget), where Sw
        is the sum of the two per-class covariance estimates from
        ``cov_shrinkage``.
    fda_b : float
        Midpoint of the two projected class means. Assuming Sw is positive
        definite, target samples tend to project above it.

    Raises
    ------
    ValueError
        If either class has no samples.
    """
    data = np.asarray(data)
    features = data.reshape(-1, data.shape[-1])
    labels = np.asarray(label).reshape(-1)
    posdata = features[labels == 1]
    negdata = features[labels == 0]
    if len(posdata) == 0 or len(negdata) == 0:
        raise ValueError("fda_train needs at least one target and one non-target sample")
    posmean = posdata.mean(axis=0)
    negmean = negdata.mean(axis=0)
    # Within-class scatter; cov_shrinkage presumably returns an
    # (n_features, n_features) regularised covariance -- see toolbox.utils.
    sw = cov_shrinkage(posdata) + cov_shrinkage(negdata)
    # solve(Sw, d) yields Sw^-1 d without forming the explicit inverse, which
    # is cheaper and numerically more stable than inv(Sw) @ d.
    fda_w = np.linalg.solve(sw, posmean - negmean)
    fda_b = (fda_w @ posmean + fda_w @ negmean) / 2
    return fda_w, fda_b
# Assess performance with a Receiver Operating Characteristic (ROC) curve
def compute_ROC(fda_w, fda_b, testset, test_labels, n_thresholds=100):
    """Compute the ROC curve of a trained FDA classifier on a test set.

    Each test sample is projected onto ``fda_w`` and shifted by ``fda_b``.
    The scores are min-max scaled to [0, 1] and swept with ``n_thresholds``
    evenly spaced thresholds running from 1 down to 0. A sample is predicted
    to be a target when its scaled score is >= the threshold. ``fda_train``
    builds ``fda_w`` from (target mean - non-target mean), so target samples
    project onto the high side, assuming the pooled covariance is positive
    definite.

    Parameters
    ----------
    fda_w : array_like, shape (n_features,)
        Discriminant direction returned by ``fda_train``.
    fda_b : float
        Bias returned by ``fda_train``.
    testset : array_like, shape (n_samples, n_features)
        Test feature vectors, one row per sample.
    test_labels : array_like, shape (n_samples,)
        1 for target samples, 0 for non-target samples.
    n_thresholds : int, optional
        Number of thresholds (points on the curve). Defaults to 100.

    Returns
    -------
    TP_curve, FP_curve : ndarray, shape (n_thresholds,)
        True-positive rate (TP / #targets) and false-positive rate
        (FP / #non-targets) per threshold. The thresholds are ordered from
        strictest to most lenient, so both rates are non-decreasing and the
        curve ends at (1, 1).

    Raises
    ------
    ValueError
        If ``test_labels`` does not contain both targets and non-targets.
    """
    labels = np.ravel(test_labels)
    is_target = labels == 1
    is_nontarget = labels == 0
    n_pos = np.count_nonzero(is_target)
    n_neg = np.count_nonzero(is_nontarget)
    if n_pos == 0 or n_neg == 0:
        raise ValueError(
            "test_labels must contain both targets (1) and non-targets (0)"
        )

    scores = np.ravel(np.matmul(testset, fda_w) - fda_b)
    # Min-max scale so the fixed threshold grid spans the observed scores.
    span = scores.max() - scores.min()
    if span > 0:
        scores = (scores - scores.min()) / span
    else:
        # Every sample has the same score, so no threshold can separate them.
        scores = np.zeros_like(scores, dtype=float)

    thresholds = np.linspace(1, 0, n_thresholds)
    # predicted[t, s] is True when sample s is called a target at threshold t.
    # It is computed afresh for every threshold, not accumulated across them.
    predicted = scores[np.newaxis, :] >= thresholds[:, np.newaxis]
    tp = np.count_nonzero(predicted & is_target, axis=1)
    fp = np.count_nonzero(predicted & is_nontarget, axis=1)
    TP_curve = tp / n_pos
    FP_curve = fp / n_neg
    return TP_curve, FP_curve
def plot_ROC(TP_curve, FP_curve):
    """Plot an ROC curve and display it.

    Parameters
    ----------
    TP_curve : array_like
        True-positive rates, drawn on the y-axis.
    FP_curve : array_like
        False-positive rates for the same thresholds, drawn on the x-axis.
    """
    plt.title("ROC Curve")
    plt.xlabel("False Positives (FP)")
    plt.ylabel("True Positives (TP)")
    # x = false positives, y = true positives, as the axis labels state.
    plt.plot(FP_curve, TP_curve)
    # Dashed diagonal: the curve of a classifier that guesses at random.
    plt.plot([0, 1], [0, 1], linestyle="--", color="grey")
    plt.show()
# Cross-validation: hold out each fold once as the test set, train on the rest.
aucs = []
for i in range(len(data_chunks)):
    test_set = data_chunks[i]
    test_label_set = label_chunks[i]
    # Merge the remaining folds into one (n_samples, n_features) training set.
    train_set = np.delete(data_chunks, i, axis=0)
    train_set = train_set.reshape(-1, train_set.shape[-1])
    train_label_set = np.delete(label_chunks, i, axis=0).reshape(-1)
    fda_w, fda_b = fda_train(train_set, train_label_set)
    TP_curve, FP_curve = compute_ROC(fda_w, fda_b, test_set, test_label_set)
    plot_ROC(TP_curve, FP_curve)
    roc = np.array([FP_curve, TP_curve])
    Auc = calc_AUC(roc)
    print(Auc)
    aucs.append(Auc)
# Assumes calc_AUC returns one scalar per fold.
print("Mean AUC:", np.mean(aucs))

# NOTE(review): the saved model comes from the last fold only, trained on all
# but one fold; consider retraining on every epoch before saving.
file_path = "fda.npy"
# np.save writes exactly one array (its third positional parameter is
# allow_pickle, not a second array), so the bias goes to its own file.
np.save(file_path, fda_w)
np.save("fda_b.npy", np.asarray(fda_b))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement