Advertisement
Guest User

Untitled

a guest
Feb 18th, 2020
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.07 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3.  
  4. # Load Data
  5. from mne import viz
  6. from numpy.linalg import inv
  7.  
  8. from toolbox.utils import chan4plot, downsample, cov_shrinkage, calc_confusion, calc_AUC
  9.  
  10. datafile = "data/p3bci_data.npz"
  11. D = np.load(datafile, allow_pickle=True)
  12. data = D['data']
  13. onsets = D['onsets']
  14. timestamps = D['timestamps']
  15. flashseq = D['flashseq']
  16. targets = D['targets']
  17.  
  18. sample_size = 205
  19.  
  20. trials = 30
  21. subtrials = 10
  22. stimnum = 12
  23.  
  24. epochs = np.zeros((3600,205,10))
  25. label = np.zeros((trials, subtrials, stimnum), dtype=int)
  26. #print(label)
  27.  
  28. c = 0
  29. for i in range(trials):
  30. for j in range(subtrials):
  31. for k in range(stimnum):
  32. flash = np.abs(timestamps-onsets[i][j][k]).argmin()
  33. sample = data[flash:flash + sample_size]
  34. epochs[c,:,:] = sample
  35. c += 1
  36.  
  37. # Assign labels
  38. for i in range(0, trials):
  39. for j in range(0, subtrials):
  40. for k in range(0, stimnum):
  41. label[i][j][k] = flashseq[i][j][k] == (targets[0, i] % 6) + 6 or flashseq[i][j][k] == targets[0, i] // 6
  42.  
  43. label = label.flatten()
  44.  
  45. targets = epochs[label==1,:,:]
  46. non_targets = epochs[label==0,:,:]
  47. # print("targets: ", len(targets), ", nontargets: ", len(non_targets))
  48.  
  49. plt.xlabel('Samples')
  50. plt.ylabel('Amplitude')
  51.  
  52. x = range(205)
  53.  
  54. target_means = np.mean(targets, 0)
  55. non_target_means = np.mean(non_targets, 0)
  56.  
  57. #plt.ion()
  58. # Targets
  59. #plt.plot(x, target_means[:, 1], color='b')
  60. # Non-Targets
  61. #plt.plot(x, non_target_means[:, 1], color='r')
  62.  
  63. #plt.show()
  64.  
  65. # Plot Head
  66. #erp = np.square(target_means[90,:] - non_target_means[90,:])
  67.  
  68. #eeginfo = chan4plot()
  69. #sba = viz.plot_topomap(erp, eeginfo, show=False)
  70.  
  71. # Reshape
  72. data = downsample(epochs, 10)
  73. flattened_data = data.reshape(3600,200)
  74.  
  75. data_chunks = np.reshape(flattened_data, (5,720,200))
  76. label_chunks = np.reshape(label, (5, 720))
  77.  
  78.  
  79. # Fisher's Linear Discriminant Analysis
  80. def fda_train(data, label):
  81. posdata = data[label==1,:]
  82. negdata = data[label==0,:]
  83.  
  84. posmean = np.mean(posdata, 0)
  85. negmean = np.mean(negdata, 0)
  86.  
  87. spos = cov_shrinkage(posdata)
  88. sneg = cov_shrinkage(negdata)
  89.  
  90. sw = spos + sneg
  91.  
  92. fda_w = inv(sw)@np.transpose(posmean - negmean)
  93.  
  94. fda_posmean = np.dot(fda_w, posmean)
  95.  
  96. fda_negmean = np.dot(fda_w, negmean)
  97.  
  98. fda_b = (fda_negmean+fda_posmean)/2
  99.  
  100. return fda_w, fda_b
  101.  
  102.  
  103. # Assess performance with a Receiver Operating Characteristic (ROC) curve
  104. def compute_ROC(fda_w, fda_b, testset, test_labels):
  105. testsize = len(testset)
  106. proj = np.zeros(testsize)
  107.  
  108.  
  109. # for j in range(testsize):
  110. # print('loop shape', testset[j, :].shape, fda_w.shape, 'result_shape', np.matmul(testset[j,:],fda_w).shape)
  111. proj = np.matmul(testset, fda_w) - fda_b
  112.  
  113. proj = np.interp(proj, (proj.min(), proj.max()), (0, 1))
  114. pred = np.zeros(testsize)
  115. thr = np.linspace(0,1,100)
  116.  
  117. TP_no = []
  118. FP_no = []
  119. print(proj)
  120. for b in range(len(thr)):
  121. pred[proj <= thr[b]] = 1
  122. TP, FP, FN, TN = calc_confusion(pred, test_labels, 1, 0)
  123. TP_no.append(TP)
  124. FP_no.append(FP)
  125. # print("Confusion Matrix: ", TP, FP, FN, TN)
  126. # print("TPs:", TP_no, "FPs:", FP_no)
  127. TP_curve = np.divide(TP_no, len(targets) / 5)
  128. FP_curve = np.divide(FP_no, len(non_targets) / 5)
  129. return TP_curve, FP_curve
  130.  
  131.  
  132. def plot_ROC(TP_curve, FP_curve):
  133. plt.title("ROC Curve")
  134. plt.xlabel("False Positives (FP)")
  135. plt.ylabel("True Positives (TP)")
  136. plt.plot(TP_curve, FP_curve)
  137. plt.show()
  138.  
  139.  
  140. for i in range(5):
  141. train_set = np.delete(data_chunks, i, axis=0)
  142. train_label_set = np.delete(label_chunks, i, axis=0)
  143. test_set = data_chunks[i]
  144. test_label_set = label_chunks[i]
  145.  
  146. fda_w, fda_b = fda_train(train_set, train_label_set)
  147. ys = np.matmul(test_set, fda_w)
  148.  
  149. TP_curve, FP_curve = compute_ROC(fda_w, fda_b, test_set, test_label_set)
  150. plot_ROC(TP_curve, FP_curve)
  151. roc = np.array([FP_curve, TP_curve])
  152. Auc = calc_AUC(roc)
  153. print(Auc)
  154.  
  155. file_path = "fda.npy"
  156. np.save(file_path, fda_w, fda_b)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement