Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from keras.datasets import mnist
- from brian2 import *
- import brian2.numpy_ as np
- import matplotlib.pyplot as plt
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
# Load the MNIST dataset (downloads on first use via keras).
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Map pixel intensity [0, 255] to a Poisson firing rate in Hz:
# dividing by 4 means a full-intensity pixel drives ~63 Hz.
X_train = X_train * 0.25
X_test = X_test * 0.25

# Collapse each 28x28 image into a flat 784-element rate vector,
# one entry per input neuron.
X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)
# Define SNN parameters
n_input = 28*28 # Input layer: one Poisson source per pixel
n_e = 100 # Excitatory neurons
n_i = n_e # Inhibitory neurons (one per excitatory neuron)
# LIF membrane potentials, excitatory population (brian2 mV units)
v_rest_e = -60.*mV # Membrane potential
v_reset_e = -65.*mV
v_thresh_e = -52.*mV
# LIF membrane potentials, inhibitory population
v_rest_i = -60.*mV
v_reset_i = -45.*mV
v_thresh_i = -40.*mV
# STDP trace time constants (pre- and post-synaptic traces decay equally)
taupre = 20*ms
taupost = taupre
gmax = .05 # upper bound for plastic synaptic weights
dApre = .01
# Post-trace increment is negative (depression) with a slight ×1.05 bias
# relative to the potentiation increment.
dApost = -dApre * taupre / taupost * 1.05
# Scale both trace increments by the maximum weight
dApost *= gmax
dApre *= gmax
# Define STDP equations: weight w, a shared learning-rate switch lr
# (set to 0 to freeze plasticity), and exponentially decaying
# event-driven pre/post spike traces.
stdp_eqs = '''
w : 1
lr : 1 (shared)
dApre/dt = -Apre / taupre : 1 (event-driven)
dApost/dt = -Apost / taupost : 1 (event-driven)'''
# Pre-synaptic spike update: deliver the weight to the target's ge,
# bump the pre-trace, and depress w by the (negative-valued) post-trace,
# clipped to [0, gmax].
stdp_pre = '''
ge += w
Apre += dApre
w = clip(w + lr*Apost, 0, gmax)'''
# Post-synaptic spike update: bump the post-trace and potentiate w by
# the (positive-valued) pre-trace, clipped to [0, gmax].
stdp_post = '''
Apost += dApost
w = clip(w + lr*Apre, 0, gmax)'''
class Model():
    """Spiking network: 784 Poisson inputs -> n_e excitatory LIF neurons
    trained with STDP, with one-to-one excitatory->inhibitory coupling and
    all-to-others inhibitory feedback for lateral competition
    (Diehl & Cook-style architecture — TODO confirm intended reference).
    """

    def __init__(self):
        # Input Poisson Group: one spike source per pixel; rates are
        # overwritten per image in train()/evaluate().
        self.PG = PoissonGroup(n_input, rates=np.zeros(n_input)*Hz, name='PG')
        # Excitatory Neuron Group: conductance-based LIF. ge pulls v toward
        # the 0 mV excitatory reversal, gi toward the -100 mV inhibitory one.
        self.EG = NeuronGroup(n_e, '''
            dv/dt = (ge*(0*mV-v) + gi*(-100*mV-v) + (v_rest_e-v)) / (100*ms) : volt
            dge/dt = -ge / (5*ms) : 1
            dgi/dt = -gi / (10*ms) : 1
            ''',
            threshold='v>v_thresh_e', refractory=5*ms, reset='v=v_reset_e',
            method='euler', name='EG')
        self.EG.v = v_rest_e - 20.*mV  # start well below rest
        # Inhibitory Neuron Group: faster 10 ms membrane, excitatory drive only.
        self.IG = NeuronGroup(n_i, '''
            dv/dt = (ge*(0*mV-v) + (v_rest_i-v)) / (10*ms) : volt
            dge/dt = -ge / (5*ms) : 1
            ''',
            threshold='v>v_thresh_i', refractory=2*ms, reset='v=v_reset_i',
            method='euler', name='IG')
        self.IG.v = v_rest_i - 20.*mV
        # Plastic input->excitatory synapses (STDP), all-to-all.
        self.S1 = Synapses(self.PG, self.EG, stdp_eqs, on_pre=stdp_pre,
                           on_post=stdp_post, method='euler', name='S1')
        self.S1.connect()
        self.S1.w = 'rand()*gmax'  # random weights initialization in [0, gmax)
        self.S1.lr = 1  # enable STDP
        # One-to-one excitatory->inhibitory synapses.
        self.S2 = Synapses(self.EG, self.IG, 'w : 1', on_pre='ge += w', name='S2')
        self.S2.connect(j='i')
        self.S2.delay = 'rand()*10*ms'
        self.S2.w = 3  # very strong fixed weights
        # Inhibitory->excitatory feedback onto every *other* excitatory
        # neuron: lateral inhibition between output units.
        self.S3 = Synapses(self.IG, self.EG, 'w : 1', on_pre='gi += w', name='S3')
        self.S3.connect(condition='i!=j')
        self.S3.delay = 'rand()*5*ms'
        self.S3.w = .03  # balanced weights
        # Initialize Brian2 Network; the zero-length run finalizes setup.
        self.net = Network(self.PG, self.EG, self.IG, self.S1, self.S2, self.S3)
        self.net.run(0*second)

    def train(self, X, epoch=1):
        """Present each sample for 0.35 s with STDP enabled, followed by
        0.15 s of silent input so activity and traces decay between samples.

        X: iterable of flattened images whose pixel values are used
        directly as firing rates in Hz.
        """
        self.S1.lr = 1  # enable STDP
        for ep in range(epoch):
            for idx in range(len(X)):
                # Active mode: drive the network with the image.
                self.PG.rates = X[idx].ravel()*Hz
                self.net.run(0.35*second)
                # Passive mode: silence the inputs.
                self.PG.rates = np.zeros(n_input)*Hz
                self.net.run(0.15*second)

    def evaluate(self, X):
        """Return one spike-count feature vector (length n_e) per sample,
        with plasticity frozen (lr = 0)."""
        self.S1.lr = 0  # disable STDP
        features = []
        for idx in range(len(X)):
            # Fresh spike monitor per image so counts are per-presentation.
            mon = SpikeMonitor(self.EG, name='RM')
            self.net.add(mon)
            # Active mode
            self.PG.rates = X[idx].ravel()*Hz
            self.net.run(0.35*second)
            # Spikes per neuron for this image. BUGFIX: use a plain int
            # dtype — the original int8 silently overflows if a neuron
            # fires more than 127 times during a presentation.
            features.append(np.array(mon.count, dtype=int))
            # Passive mode
            self.PG.rates = np.zeros(n_input)*Hz
            self.net.run(0.15*second)
            self.net.remove(self.net['RM'])
        return features
- import seaborn as sns
# Test the SNN model with evaluation metrics and confusion matrix plotting
def test_snn(train_items=500, assign_items=100, eval_items=100):
    """Train the SNN on `train_items` images, extract spike-count features
    for `assign_items` train and `eval_items` test images, score a crude
    threshold classifier, and plot confusion matrices.

    Returns (train_features, test_features).
    """
    seed(0)  # brian2 seed: reproducible Poisson spikes and rand() inits
    model = Model()
    model.train(X_train[:train_items], epoch=1)
    train_features = model.evaluate(X_train[:assign_items])
    test_features = model.evaluate(X_test[:eval_items])
    # Perform classification using a simple thresholding method.
    # NOTE(review): this collapses predictions to {0, 1} while the labels
    # span 10 classes, so the metrics below are only a crude baseline.
    threshold = 10  # Example threshold value
    train_predictions = [1 if np.sum(f) > threshold else 0 for f in train_features]
    test_predictions = [1 if np.sum(f) > threshold else 0 for f in test_features]
    # Calculate evaluation metrics. zero_division=0 avoids undefined-metric
    # warnings for the classes this classifier never predicts.
    train_accuracy = accuracy_score(y_train[:assign_items], train_predictions)
    test_accuracy = accuracy_score(y_test[:eval_items], test_predictions)
    train_precision = precision_score(y_train[:assign_items], train_predictions, average='weighted', zero_division=0)
    test_precision = precision_score(y_test[:eval_items], test_predictions, average='weighted', zero_division=0)
    train_recall = recall_score(y_train[:assign_items], train_predictions, average='weighted', zero_division=0)
    test_recall = recall_score(y_test[:eval_items], test_predictions, average='weighted', zero_division=0)
    train_f1 = f1_score(y_train[:assign_items], train_predictions, average='weighted', zero_division=0)
    test_f1 = f1_score(y_test[:eval_items], test_predictions, average='weighted', zero_division=0)
    # BUGFIX: force full 10x10 matrices via labels=. Without it,
    # confusion_matrix only includes classes that actually appear, so its
    # shape no longer matches the fixed np.arange(10) heatmap tick labels.
    class_labels = np.arange(10)
    train_confusion_matrix = confusion_matrix(y_train[:assign_items], train_predictions, labels=class_labels)
    test_confusion_matrix = confusion_matrix(y_test[:eval_items], test_predictions, labels=class_labels)
    print("Train Accuracy:", train_accuracy)
    print("Test Accuracy:", test_accuracy)
    print("Train Precision:", train_precision)
    print("Test Precision:", test_precision)
    print("Train Recall:", train_recall)
    print("Test Recall:", test_recall)
    print("Train F1 Score:", train_f1)
    print("Test F1 Score:", test_f1)
    print("Train Confusion Matrix:\n", train_confusion_matrix)
    print("Test Confusion Matrix:\n", test_confusion_matrix)
    # Plot confusion matrices
    plot_confusion_matrix(train_confusion_matrix, class_labels)
    plot_confusion_matrix(test_confusion_matrix, class_labels)
    return train_features, test_features
# Function to plot confusion matrix
def plot_confusion_matrix(confusion_matrix, labels):
    """Render a confusion matrix as an annotated integer heatmap."""
    plt.figure(figsize=(10, 8))
    ax = sns.heatmap(
        confusion_matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
    )
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')
    plt.show()
# Example usage — trains and evaluates the SNN.
# NOTE(review): runs at import time; consider an `if __name__ == "__main__":`
# guard so importing this module doesn't trigger a full training run.
train_features, test_features = test_snn(train_items=500, assign_items=100, eval_items=100)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement