Advertisement
Guest User

Keras image classifier using VGG16

a guest
Apr 22nd, 2017
673
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.73 KB | None | 0 0
  1. image_size = 128
  2. mean_pixel = [103.939, 116.779, 123.68]
  3.  
  4.  
  5. def get_img(animal, index):
  6.     file_path = 'data/train/%ss/%s%03d.jpg' % (animal, animal, index)
  7.     img = cv2.resize(cv2.imread(file_path), (image_size, image_size))
  8.     for c in range(3):
  9.         img[:, :, c] = img[:, :, c] - mean_pixel[c]
  10.     return img
  11.  
  12.  
  13. def get_data():
  14.     X = []
  15.     Y = []
  16.     for i in range(0, 10000):
  17.         X.append(get_img("cat", i))
  18.         Y.append([0, 1, 0, 0, 0, 0, 0, 0])
  19.         X.append(get_img("dog", i))
  20.         Y.append([1, 0, 0, 0, 0, 0, 0, 0])
  21.     return np.array(X), np.array(Y)
  22.  
  23.  
  24. if __name__ == "__main__":
  25.     vgg16 = VGG16(weights='imagenet', include_top=False)
  26.  
  27.     input = Input(shape=(image_size, image_size, 3), name='image_input')
  28.     output_vgg16_conv = vgg16(input)
  29.  
  30.     x = Flatten(name='flatten')(output_vgg16_conv)
  31.     x = Dense(4096, activation='relu', name='fc1')(x)
  32.     x = Dense(4096, activation='relu', name='fc2')(x)
  33.     x = Dense(8, activation='softmax', name='predictions')(x)
  34.  
  35.     my_model = Model(input=input, output=x)
  36.     my_model.summary()
  37.  
  38.     weights_path = "my_vgg16_weights.h5"
  39.  
  40.     if os.path.isfile(weights_path):
  41.         my_model.load_weights(weights_path)
  42.  
  43.     checkpoint = ModelCheckpoint(weights_path, monitor='val_acc', verbose=0, save_best_only=False, mode='max')
  44.     my_model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
  45.  
  46.     X, Y = get_data()
  47.     my_model.fit(X, Y, epochs=20, verbose=1, callbacks=[checkpoint], shuffle=True)
  48.  
  49.     cat_img = get_img("cat", 10001)
  50.     cat_im = np.expand_dims(cat_img, axis=0)
  51.     print(my_model.predict(cat_im))
  52.  
  53.     dog_img = get_img("dog", 10001)
  54.     dog_im = np.expand_dims(dog_img, axis=0)
  55.     print(my_model.predict(dog_im))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement