Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Standard imports for numerics, data handling, and plotting.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
# All layers come from the public tf.keras API. The original also imported
# MaxPool2D from tensorflow_core.python.keras.layers.pooling — a private
# internal path that breaks across TF releases; MaxPool2D is the public
# alias of MaxPooling2D in tensorflow.keras.layers.
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPool2D,
    MaxPooling2D,
)
from tensorflow.keras import optimizers
from sklearn.model_selection import train_test_split

import preprocessing as pre  # project-local image preprocessing helpers

print(tf.__version__)
# ---------------------------------------------------------------------------
# Data loading and preprocessing.
# Images arrive as a pickled array; labels come from the companion CSV.
# ---------------------------------------------------------------------------
train_x = pd.read_pickle('data/train_max_x')
train_y = pd.read_csv("data/train_max_y.csv").Label.values

# Hold out 20% of the data for validation; the fixed seed makes the split
# reproducible across runs.
train_x, val_x, train_y, val_y = train_test_split(
    train_x, train_y, test_size=0.2, random_state=1, shuffle=True
)

print("Pre-processing")
# Binarize every image at the same brightness threshold so only the
# brightest strokes survive. (To sanity-check, plot a few samples with
# plt.imshow(train_x[i], cmap=plt.cm.binary) after this step.)
train_x, val_x = (pre.binarize(split, threshold=230) for split in (train_x, val_x))

print("Input")
# Network input geometry: 128x128 single-channel images, 10 digit classes.
input_shape = (128, 128, 1)
num_classes = 10

# Append the trailing channel axis that Conv2D expects: (N, H, W) -> (N, H, W, 1).
train_x = np.expand_dims(train_x, axis=-1)
val_x = np.expand_dims(val_x, axis=-1)
# ---------------------------------------------------------------------------
# CNN: five conv/pool stages with doubling filter widths, then a dense head.
# ---------------------------------------------------------------------------
model = keras.Sequential()

# Stage 1: default 'valid' padding, so the 128x128 input shrinks slightly.
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                 kernel_initializer='he_normal', input_shape=input_shape))
model.add(MaxPool2D((2, 2)))
model.add(Dropout(0.20))

# Stages 2-5: 'same' padding, each halving spatial resolution via pooling.
# The final stage has no dropout before flattening (matches the original).
for filters, drop_rate in ((64, 0.25), (128, 0.25), (256, 0.25), (512, None)):
    model.add(Conv2D(filters, (3, 3), activation='relu', padding='same',
                     kernel_initializer='he_normal'))
    model.add(MaxPool2D(pool_size=(2, 2)))
    if drop_rate is not None:
        model.add(Dropout(drop_rate))

# Classifier head.
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Sparse categorical cross-entropy: train_y holds integer class ids,
# not one-hot vectors.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
print("Fitting")
# Train for 20 epochs; per-epoch training and validation metrics are
# recorded in `history.history` for plotting below.
history = model.fit(train_x, train_y, epochs=20, batch_size=128, verbose=2, validation_data=(val_x, val_y))
# Persist the full model (architecture + weights) in HDF5 format.
model.save('models/cnn_model_AlexNet.h5')
print("Saved")
print("Evaluating")
# Final inference-mode pass over both splits (dropout disabled, batch-norm
# using running statistics), so train accuracy here can differ from the
# last epoch's training accuracy reported by fit().
train_loss, train_acc = model.evaluate(train_x, train_y, verbose=2)
val_loss, val_acc = model.evaluate(val_x, val_y, verbose=2)
print('\nTrain accuracy:', train_acc)
print('\nValidation accuracy:', val_acc)
# Test-set prediction pipeline, kept disabled; note it reads 'test_max_x'
# from the working directory, not from data/ like the training files.
# print("Predicting on test set")
# test_x = pd.read_pickle('test_max_x')
# test_x = test_x.reshape(test_x.shape[0], test_x.shape[1], test_x.shape[2], 1)
# predictions = np.argmax(model.predict(test_x), axis=1)
# df = pd.DataFrame(pd.Series(predictions))
# df.to_csv('data/predict_test_model_AlexNet.csv')
print(history.history.keys())
# TF 2.x records accuracy under 'accuracy'/'val_accuracy'; TF 1.x used
# 'acc'/'val_acc'. The original hard-coded the TF 1.x names, which raises
# KeyError under the tf.keras 2.x API this script imports. Resolve
# whichever pair is actually present.
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
val_acc_key = 'val_accuracy' if 'val_accuracy' in history.history else 'val_acc'

# Accuracy curves (train vs validation, per epoch).
plt.plot(history.history[acc_key])
plt.plot(history.history[val_acc_key])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

# Loss curves (train vs validation, per epoch).
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement