# --- directory structure ---
# image_classification.py
# data:
#     class1:
#         1.jpeg
#         2.jpeg
#         ...
#     class2:
#         1.jpeg
#         2.jpeg
#         ...
#     class3:
#         1.jpeg
#         xyz.jpeg
#         ...

from keras.applications.inception_v3 import InceptionV3
from keras.preprocessing import image
from keras.models import Model
from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Activation, Flatten
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.utils import np_utils
from keras import backend as K

import cv2
import skimage.transform  # the transform submodule must be imported explicitly for skimage.transform.resize
import matplotlib.pyplot as plt

import numpy as np

import os, time
import glob

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def load_normalized_image(filename, image_size):
    # Resize to (image_size, image_size) and shift/scale 8-bit pixel values from [0, 255] to roughly [-0.5, 0.5]
    return (skimage.transform.resize(plt.imread(filename).astype(float), [image_size, image_size]) - 255.0/2) / 255.0

def generator(data_root, images_list, labels_list, batch_size, image_size, class_names):
    num_classes = len(class_names)
    batch_features = np.zeros((batch_size, image_size, image_size, 3))

    while True:
        labels = []

        for i in range(batch_size):
            index = np.random.choice(len(images_list))
            image_path = os.path.join(data_root, images_list[index])
            img = load_normalized_image(image_path, image_size)
            if len(img.shape) < 3:
                # grayscale image: replicate the single channel to get 3 channels
                img = cv2.merge((img, img, img))
            elif img.shape[2] > 3:
                # RGBA image: drop the alpha channel
                img = img[:, :, 0:3]
            batch_features[i] = img
            labels.append(labels_list[index])

        batch_labels = np_utils.to_categorical(labels, num_classes)
        yield batch_features, batch_labels

def load_data(data_dirs, filetypes):
    data = []
    labels = []
    for idx, data_dir in enumerate(data_dirs):
        files = []
        for filetype in filetypes:
            files.extend(glob.glob(os.path.join(data_dir, '*.' + filetype)))
        data.append(files)
        labels.append(np.ones(len(data[idx]), dtype=int) * idx)

    return data, labels

data_directory = 'data'
filetypes = ['jpeg', 'jpg', 'png']

checkpoints_directory = 'checkpoints'
log_directory = 'logs'

class_names = [os.path.basename(d) for d in glob.glob(os.path.join(data_directory, '*')) if os.path.isdir(d)]
data_dirs = [os.path.join(data_directory, class_name) for class_name in class_names]

# Split data into training and validation sets

data, labels = load_data(data_dirs, filetypes)

x_train = []
y_train = []
x_test = []
y_test = []
for i in range(len(data)):
    # first 160 images of each class go to training, the rest to validation
    x_train.extend(data[i][:160])
    y_train.extend(labels[i][:160])

    x_test.extend(data[i][160:])
    y_test.extend(labels[i][160:])

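# Optional sanity check of the split (a sketch, not part of the training pipeline):
# print how many images landed in each set. It assumes every class directory holds
# more than 160 images; smaller classes would contribute nothing to validation.
print('classes:', class_names)
print('training samples per class:  ', np.bincount(y_train, minlength=len(class_names)))
print('validation samples per class:', np.bincount(y_test, minlength=len(class_names)))
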
# Hyperparameters
image_size = 224
image_channels = 3
batch_size = 8
epochs = 10

input_shape = (image_size, image_size, image_channels)
num_classes = len(data)

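# Optional smoke test of the generator before training (a sketch under the assumption
# that the data/ directory described at the top of this file exists on disk):
# pull a single batch and confirm the shapes look right.
sample_x, sample_y = next(generator('./', x_train, y_train, batch_size, image_size, class_names))
print('batch features:', sample_x.shape)  # expected: (batch_size, image_size, image_size, 3)
print('batch labels:  ', sample_y.shape)  # expected: (batch_size, num_classes)
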
model = InceptionV3(weights='imagenet', include_top=False, input_shape=input_shape)
# model.summary()

# Freeze the pre-trained InceptionV3 layers; only the new top layers will be trained
for layer in model.layers:
    layer.trainable = False

# Adding custom layers
x = model.output
x = Flatten()(x)
x = Dense(4096, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(4096, activation="relu")(x)
x = Dropout(0.5)(x)
predictions = Dense(num_classes, activation="softmax")(x)

# Creating the final model
model_final = Model(inputs=model.input, outputs=predictions)
# model_final.summary()

# Compile model
model_final.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])

# Callbacks
if not os.path.exists(checkpoints_directory):
    os.makedirs(checkpoints_directory)

chkpoint_callback = ModelCheckpoint(
    os.path.join(checkpoints_directory, 'weights.{epoch:02d}.hdf5'),
    monitor='val_acc',
    verbose=1,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    period=10)

tb_callback = TensorBoard(
    log_dir=log_directory,
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_images=True)

hist = model_final.fit_generator(
    generator('./', x_train, y_train, batch_size, image_size, class_names),
    # steps_per_epoch / validation_steps count batches, not individual samples
    steps_per_epoch=len(x_train) // batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=generator('./', x_test, y_test, batch_size, image_size, class_names),
    validation_steps=len(x_test) // batch_size,
    callbacks=[chkpoint_callback, tb_callback])
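
# After training, the model can be saved and reused for single-image prediction.
# A minimal sketch: 'final_model.h5' and 'some_image.jpeg' are placeholder names,
# not files produced by the code above.
model_final.save(os.path.join(checkpoints_directory, 'final_model.h5'))

test_img = load_normalized_image('some_image.jpeg', image_size)
if len(test_img.shape) < 3:
    test_img = cv2.merge((test_img, test_img, test_img))
elif test_img.shape[2] > 3:
    test_img = test_img[:, :, 0:3]
probabilities = model_final.predict(np.expand_dims(test_img, axis=0))[0]
print('predicted class:', class_names[int(np.argmax(probabilities))])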