Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Import necessary libraries
- import numpy as np
- import matplotlib.pyplot as plt
- from tensorflow.keras.datasets import mnist
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, SimpleRNN
- from tensorflow.keras.utils import to_categorical
- from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
- import seaborn as sns
- from brian2 import *
# Load the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Number of target classes (digits 0-9). Defined before it is used so the
# one-hot encoding below and the rest of the script share one source of truth
# instead of repeating the literal 10.
num_classes = 10

# Preprocess the images: add a trailing channel axis to match the CNN's
# input_shape=(28, 28, 1), cast to float32 and divide by 255 (presumably 8-bit
# pixel values) so inputs land in [0, 1].
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255

# One-hot encode the integer labels to match the softmax outputs and the
# categorical_crossentropy loss used by both models.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
# Define the CNN model: one Conv2D + MaxPooling2D stage, then a dense
# classifier head. It consumes 4-D batches shaped (N, 28, 28, 1).
cnn_model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    # Output width follows num_classes so it always matches the one-hot labels.
    Dense(num_classes, activation='softmax'),
])
# Softmax probabilities + one-hot targets -> categorical cross-entropy loss.
cnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Train the CNN for 20 epochs, holding out a 10% slice of the training data
# for per-epoch validation (validation_split=0.1).
cnn_history = cnn_model.fit(
    x_train,
    y_train,
    validation_split=0.1,
    epochs=20,
    batch_size=128,
)

# Score on the test set; unpacks into (loss, accuracy) because the model was
# compiled with metrics=['accuracy'].
cnn_loss, cnn_accuracy = cnn_model.evaluate(x_test, y_test)

# Collapse predicted class probabilities and one-hot targets to class ids.
cnn_predictions = cnn_model.predict(x_test).argmax(axis=1)
cnn_true_labels = y_test.argmax(axis=1)

# Support-weighted averages over the classes, plus the raw confusion matrix
# (rows = true class, columns = predicted class).
_cnn_pair = (cnn_true_labels, cnn_predictions)
cnn_precision = precision_score(*_cnn_pair, average='weighted')
cnn_recall = recall_score(*_cnn_pair, average='weighted')
cnn_f1 = f1_score(*_cnn_pair, average='weighted')
cnn_conf_matrix = confusion_matrix(*_cnn_pair)

print(f'CNN Test Accuracy: {cnn_accuracy}')
print(f'CNN Precision: {cnn_precision}')
print(f'CNN Recall: {cnn_recall}')
print(f'CNN F1 Score: {cnn_f1}')

# Annotated heatmap of the CNN confusion matrix, one tick per class id.
_cnn_ticks = range(num_classes)
plt.figure(figsize=(8, 6))
sns.heatmap(
    cnn_conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=_cnn_ticks,
    yticklabels=_cnn_ticks,
)
plt.title('Confusion Matrix - CNN')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
# Define the RNN model: a single SimpleRNN layer that treats each 28x28 image
# as a 28-step sequence of 28-dimensional vectors and keeps only its final
# hidden state (return_sequences=False), followed by a softmax classifier.
#
# NOTE: input_shape=(28, 28) means this model expects 3-D batches (N, 28, 28).
# The preprocessed x_train/x_test arrays are 4-D (N, 28, 28, 1), so their
# trailing channel axis must be removed before they are passed to this model.
rnn_model = Sequential([
    SimpleRNN(128, input_shape=(28, 28), activation='relu', return_sequences=False),
    # Output width follows num_classes so it always matches the one-hot labels.
    Dense(num_classes, activation='softmax'),
])
# Same optimizer/loss/metric setup as the CNN so the two are comparable.
rnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# The shared image arrays were preprocessed to 4-D (N, 28, 28, 1) for the CNN,
# but the RNN is built with input_shape=(28, 28) and needs 3-D (N, 28, 28)
# sequences. Passing the 4-D arrays directly does not match the model's input
# shape, so drop the trailing size-1 channel axis for the RNN only.
x_train_rnn = np.squeeze(x_train, axis=-1)
x_test_rnn = np.squeeze(x_test, axis=-1)

# Train the RNN with the same schedule as the CNN (20 epochs, batch 128,
# 10% validation split).
rnn_history = rnn_model.fit(x_train_rnn, y_train, batch_size=128, epochs=20, validation_split=0.1)

# Evaluate the RNN on the test set; unpacks into (loss, accuracy) because the
# model was compiled with metrics=['accuracy'].
rnn_loss, rnn_accuracy = rnn_model.evaluate(x_test_rnn, y_test)

# Collapse predicted class probabilities and one-hot targets to class ids.
rnn_predictions = rnn_model.predict(x_test_rnn)
rnn_predictions = np.argmax(rnn_predictions, axis=1)
rnn_true_labels = np.argmax(y_test, axis=1)

# Support-weighted averages over the classes, plus the raw confusion matrix.
rnn_precision = precision_score(rnn_true_labels, rnn_predictions, average='weighted')
rnn_recall = recall_score(rnn_true_labels, rnn_predictions, average='weighted')
rnn_f1 = f1_score(rnn_true_labels, rnn_predictions, average='weighted')
rnn_conf_matrix = confusion_matrix(rnn_true_labels, rnn_predictions)

print(f'RNN Test Accuracy: {rnn_accuracy}')
print(f'RNN Precision: {rnn_precision}')
print(f'RNN Recall: {rnn_recall}')
print(f'RNN F1 Score: {rnn_f1}')

# Plot confusion matrix for RNN (rows = true class, columns = predicted).
plt.figure(figsize=(8, 6))
sns.heatmap(rnn_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
plt.title('Confusion Matrix - RNN')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
# Overlay per-epoch training and validation accuracy for both models on a
# single chart. Each History object's .history dict is keyed by metric name;
# curves are drawn in the listed order, so line colours follow that order.
_accuracy_curves = [
    (cnn_history.history['accuracy'], 'CNN Train Accuracy'),
    (cnn_history.history['val_accuracy'], 'CNN Validation Accuracy'),
    (rnn_history.history['accuracy'], 'RNN Train Accuracy'),
    (rnn_history.history['val_accuracy'], 'RNN Validation Accuracy'),
]
for _curve, _curve_label in _accuracy_curves:
    plt.plot(_curve, label=_curve_label)

plt.title('CNN and RNN Model Training History')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement