Guest User

Untitled

a guest
Jan 19th, 2018
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.42 KB | None | 0 0
  1. import numpy as np
  2. from keras.datasets import mnist
  3. from keras.utils import np_utils
  4. from keras.models import Sequential
  5. from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D
  6. import matplotlib.pyplot as plt
  7. import pandas as pd
  8.  
  9. #avoid warning of AV,AVX....
  10. import os
  11. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  12.  
  13. np.random.seed(10)
  14.  
  15. def main():
  16. ((x_train,y_train), (x_test,y_test)) = mnist.load_data()
  17.  
  18. x_train4D = x_train.reshape(x_train.shape[0],28,28,1).astype('float32')
  19. x_test4D = x_test.reshape(x_test.shape[0],28,28,1).astype('float32')
  20.  
  21. x_train4D_normalize = x_train4D / 255
  22. x_test4D_normalize = x_test4D / 255
  23.  
  24. y_TrainOntHot = np_utils.to_categorical(y_train)
  25. y_TestOneHot = np_utils.to_categorical(y_test)
  26.  
  27. model = Sequential()
  28. model.add(Conv2D(filters=16,
  29. kernel_size=(5,5),
  30. padding='same',
  31. input_shape=(28,28,1),
  32. activation='relu'))
  33.  
  34. model.add(MaxPooling2D(pool_size=(2, 2)))
  35. model.add(Conv2D(filters=36,
  36. kernel_size=(5,5),
  37. padding='same',
  38. input_shape=(28,28,1),
  39. activation='relu'))
  40. model.add(MaxPooling2D(pool_size=(2, 2)))
  41. model.add(Dropout(0.25))
  42. model.add(Flatten())
  43. model.add(Dense(128, activation='relu'))
  44. model.add(Dropout(0.5))
  45. model.add(Dense(10,activation='softmax'))
  46.  
  47. # show model's layers
  48. print(model.summary())
  49. model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
  50. # decide optimization method
  51. train_history=model.fit(x=x_train4D_normalize,
  52. y=y_TrainOntHot,
  53. validation_split=0.2,
  54. epochs=10,
  55. batch_size=300,
  56. verbose=2)
  57.  
  58. show_train_history(train_history,'acc','val_acc')
  59.  
  60. '''not complete '''
  61. # try:
  62. # model.load_weights("SaveModels/MNIST_CNN.h5")
  63. # print("load weights.")
  64. # except:
  65. # model.save_weights("SaveModels/MNIST_CNN.h5")
  66. # print("Save weights.")
  67.  
  68. #show accuracy
  69. scores = model.evaluate(x_test4D_normalize, y_TestOneHot)
  70. print(scores[1])
  71. #show confuse_matrix
  72. confuse_matrix = pd.crosstab(y_test_label,
  73. prediction,
  74. rownames=['label'],
  75. colnames=['predict'])
  76. print(confuse_matrix)
  77.  
  78. prediction = model.predict_classes(x_test4D_normalize)
  79. plot_images_labels_prediction(x_test,y_test,prediction,idx=340)
  80.  
  81.  
  82. #show graph of accuracy
  83. def show_train_history(train_history,train,validation):
  84. plt.plot(train_history.history[train])
  85. plt.plot(train_history.history[validation])
  86. plt.title('Train History')
  87. plt.ylabel(train)
  88. plt.xlabel('Epoch')
  89. plt.legend(['train', 'validation'], loc='upper left')
  90. plt.show()
  91.  
  92. #show MNIST recongition result
  93. def plot_images_labels_prediction(images,labels,prediction,idx,num=10):
  94.  
  95. fig = plt.gcf()
  96. fig.set_size_inches(12, 14)
  97. if num>25: num=25
  98. for i in range(0,num):
  99. ax=plt.subplot(5,5, 1+i)
  100. ax.imshow(images[idx],cmap='binary')
  101. title="label="+str(labels[idx])
  102. if len(prediction)>0:
  103. title+=",predict="+str(prediction[idx])
  104.  
  105. ax.set_title(title,fontsize=10)
  106. ax.set_xticks([]);ax.set_yticks([])
  107. idx+=1
  108. plt.show()
  109.  
  110. if __name__ == '__main__':
  111. main()
Add Comment
Please, Sign In to add comment