Advertisement
Guest User

Untitled

a guest
Dec 9th, 2019
203
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.89 KB | None | 0 0
  1. # baseline CNN model for a QuickDraw image subset (one folder per class)
  2. import sys
  3. from os import listdir
  4. from random import shuffle
  5. import numpy as np
  6. from matplotlib import pyplot as plt
  7. from matplotlib.image import imread
  8. from keras.models import Sequential
  9. from keras.layers import Conv2D
  10. from keras.layers import MaxPooling2D
  11. from keras.layers import Dense
  12. from keras.layers import Flatten
  13. from keras.optimizers import SGD
  14. from keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
  15.  
  16. # define cnn model
  17. def define_model():
  18.     model = Sequential()
  19.     model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform',
  20.                      padding='same', input_shape=(28, 28, 3)))
  21.     model.add(MaxPooling2D((2, 2)))
  22.     model.add(Flatten())
  23.     model.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
  24.     model.add(Dense(6, activation='sigmoid'))
  25.     # compile model
  26.     opt = SGD(lr=0.001, momentum=0.9)
  27.     model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
  28.     return model
  29.  
  30. # plot diagnostic learning curves
  31. def summarize_diagnostics(history):
  32.     # plot loss
  33.     plt.subplot(2, 1, 1)
  34.     plt.title('Cross Entropy Loss')
  35.     plt.plot(history.history['loss'], color='blue', label='train')
  36.     plt.plot(history.history['val_loss'], color='orange', label='test')
  37.     plt.legend()
  38.     # plot accuracy
  39.     plt.subplot(2, 1, 2)
  40.     plt.title('Classification Accuracy')
  41.     plt.plot(history.history['acc'], color='blue', label='train')
  42.     plt.plot(history.history['val_acc'], color='orange', label='test')
  43.     plt.legend()
  44.     # save plot to file
  45.     filename = sys.argv[0].split('/')[-1]
  46.     plt.savefig(filename + '_plot.png')
  47.     plt.tight_layout()
  48.     plt.show()
  49.     plt.close()
  50.  
  51. # run the test harness for evaluating a model
  52. def run_test_harness():
  53.     # define model
  54.     model = define_model()
  55.     # create data generator
  56.     datagen = ImageDataGenerator(rescale=1.0/255.0)
  57.     # prepare iterators
  58.     train_it = datagen.flow_from_directory('quickdraw/train/',
  59.                                            class_mode='categorical',
  60.                                            batch_size=64,
  61.                                            target_size=(28, 28))
  62.     test_it = datagen.flow_from_directory('quickdraw/test/',
  63.                                           class_mode='categorical',
  64.                                           batch_size=64,
  65.                                           target_size=(28, 28))
  66.     labels = train_it.class_indices
  67.     # switch keys and values
  68.     labels = {v: k for k, v in labels.items()}
  69.     print(labels)
  70.     # fit model
  71.     history = model.fit_generator(train_it,
  72.                                   steps_per_epoch=len(train_it),
  73.                                   validation_data=test_it,
  74.                                   validation_steps=len(test_it),
  75.                                   epochs=50,
  76.                                   verbose=1)
  77.     # evaluate model
  78.     _, acc = model.evaluate_generator(test_it,
  79.                                       steps=len(test_it),
  80.                                       verbose=1)
  81.  
  82.     print(f'Accuracy on test images: {acc * 100:.2f} %')
  83.     # learning curves
  84.     summarize_diagnostics(history)
  85.    
  86.     # load 5 random images and print the output
  87.     images = []
  88.     for l in labels.values():
  89.         folder = f'quickdraw/test/{l}/'
  90.         images += [folder + f for f in listdir(folder)]
  91.     shuffle(images)
  92.     for i in range(10):
  93.         filename = images.pop(0)
  94.         plt_img = imread(filename)
  95.         # plot raw pixel data
  96.         plt.imshow(plt_img, cmap="gray")
  97.         plt.show()
  98.        
  99.         img = load_img(filename, target_size=(28, 28))
  100.         x = img_to_array(img)
  101.         x = np.expand_dims(x, axis=0)
  102.         pred = model.predict(x)
  103.         print(pred)
  104.  
  105. # entry point, run the test harness
  106. run_test_harness()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement