Guest User

Untitled

a guest
Jul 22nd, 2018
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.10 KB | None | 0 0
  1. from keras import applications
  2. from keras.preprocessing.image import ImageDataGenerator
  3. from keras import optimizers
  4. from keras.models import Sequential, Model
  5. from keras.layers import Dropout, Flatten, Dense, GlobalAveragePooling2D
  6. from keras import backend as k
  7. from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard, EarlyStopping
  8. from keras.models import load_model
  9. import os
  10. import pickle
  11. from keras.models import model_from_json
  12. import matplotlib.pyplot as plt
  13.  
  14. image_width, image_height= 256, 256
  15.  
  16. nb_train_samples= 11000
  17. nb_validation_sample=2000
  18. batch_size = 8
  19.  
  20. model = applications.VGG19(weights= "imagenet", include_top=False, input_shape=(image_height, image_width,3))
  21.  
  22. #freezing the first 5 layers, to avoid excess computation.
  23. for layer in model.layers[:5]:
  24. layer.trainable = False
  25.  
  26. #creating a fully connected layer.
  27. x=model.output
  28. x=Flatten()(x)
  29. x=Dense(1024, activation="relu")(x)
  30. x=Dropout(0.5)(x)
  31. x=Dense(384, activation="relu")(x)
  32. x=Dropout(0.5)(x)
  33. x=Dense(96, activation="relu")(x)
  34. x=Dropout(0.5)(x)
  35. #Dense(num, activation="softmax"), the num signifies the number of classes in the dataset.
  36. predictions = Dense(30, activation="softmax")(x)
  37.  
  38.  
  39. model_final =Model(input=model.input, output=predictions)
  40. #model_final = load_model("weights_VGG.h5")
  41. model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.nadam(lr=0.00001), metrics=["accuracy"])
  42.  
  43. train_datagen = ImageDataGenerator(rescale = 1./255,
  44. shear_range = 0.2,
  45. zoom_range = 0.2,
  46. horizontal_flip = True,
  47. fill_mode="nearest",
  48. width_shift_range=0.3,
  49. height_shift_range=0.3,
  50. rotation_range=30)
  51.  
  52. test_datagen = ImageDataGenerator(rescale = 1./255,
  53. horizontal_flip = True,
  54. fill_mode = "nearest",
  55. zoom_range = 0.3,
  56. width_shift_range = 0.3,
  57. height_shift_range=0.3,
  58. rotation_range=30)
  59.  
  60. training_set = train_datagen.flow_from_directory('./HE_Chal',
  61. target_size = (256, 256),
  62. batch_size = 8,
  63. class_mode = 'categorical')
  64.  
  65. test_set = test_datagen.flow_from_directory('./Validation',
  66. target_size = (256, 256),
  67. batch_size = 8,
  68. class_mode = 'categorical')
  69.  
  70. model_final.fit_generator(training_set, steps_per_epoch = 1000,epochs = 80, validation_data = test_set,validation_steps=1000)
  71. model_json=model_final.to_json()
  72. with open("model.json", "w") as json_file:
  73. json_file.write(model_json)
  74. model_final.save_weights("weights_VGG.h5")
  75. model_final.save("model_27.h5")
  76. #model_final.predict(test_set, batch_size=batch_size)
  77. '''
  78. json_file = open('model.json', 'r')
  79. loaded_model_json = json_file.read()
  80. json_file.close()
  81. loaded_model = model_from_json(loaded_model_json)
  82. # load weights into new model
  83. loaded_model.load_weights("weights_VGG.h5",by_name=True)
  84. print("Loaded model from disk")
  85.  
  86. # evaluate loaded model on test data
  87. loaded_model.compile(loss='categorical_crossentropy', optimizer='nadam', metrics=['accuracy'])
  88.  
  89. #print(loaded_model.summary())
  90. loaded_model.fit_generator(training_set, steps_per_epoch = 1000,epochs = 100, validation_data = test_set,validation_steps=1000)
  91. #score = loaded_model.evaluate(training_set,test_set , verbose=0)
  92. '''
  93. print(history.history.keys())
  94. # summarize history for accuracy
  95. plt.plot(history.history['acc'])
  96. plt.plot(history.history['val_acc'])
  97. plt.title('model accuracy')
  98. plt.ylabel('accuracy')
  99. plt.xlabel('epoch')
  100. plt.legend(['train', 'test'], loc='upper left')
  101. plt.show()
  102. # summarize history for loss
  103. plt.plot(history.history['loss'])
  104. plt.plot(history.history['val_loss'])
  105. plt.title('model loss')
  106. plt.ylabel('loss')
  107. plt.xlabel('epoch')
  108. plt.legend(['train', 'test'], loc='upper left')
  109. plt.show()
Add Comment
Please, Sign In to add comment