Guest User

Untitled

a guest
Mar 24th, 2018
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.95 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": null,
  6. "metadata": {
  7. "collapsed": true
  8. },
  9. "outputs": [],
  10. "source": [
  11. "#deducing in-, output and on-top layer vgg16 is made out of 5 blocks with 18 layers. below list marks at what position\n",
  12. "#a block begins\n",
  13. "blocks=[18, 14, 10, 6, 3, 0]\n",
  14. "\n",
  15. "def trainVGG16(vgg16Weights,augumentation, i): \n",
  16. " \n",
  17. " #do not use pre trained weights if vgg16Weights is FALSE\n",
  18. " if vgg16Weights:\n",
  19. " model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))\n",
  20. " useWeight='Weights'\n",
  21. " else:\n",
  22. " model = applications.VGG16(include_top=False, input_shape=(150, 150, 3))\n",
  23. " useWeight='NoWeights'\n",
  24. " print('Model loaded.')\n",
  25. "\n",
  26. " # build a classifier model to put on top of the convolutional model\n",
  27. " top_model = Sequential()\n",
  28. " top_model.add(Flatten(input_shape=model.output_shape[1:]))\n",
  29. " top_model.add(Dense(256, activation='relu'))\n",
  30. " top_model.add(Dropout(0.5))\n",
  31. " top_model.add(Dense(1, activation='sigmoid'))\n",
  32. "\n",
  33. " # note that it is necessary to start with a fully-trained\n",
  34. " # classifier, including the top classifier,\n",
  35. " # in order to successfully do fine-tuning\n",
  36. " top_model.load_weights(top_model_weights_path)\n",
  37. "\n",
  38. " # add the model on top of the convolutional base\n",
  39. " model = Model(inputs=model.input, outputs=top_model(model.output))#model.add(top_model)\n",
  40. "\n",
  41. " # set the first 25 layers (up to the last conv block)\n",
  42. " # to non-trainable (weights will not be updated)\n",
  43. " for layer in model.layers[:i]:\n",
  44. " layer.trainable = False \n",
  45. " \n",
  46. "\n",
  47. " # compile the model with a SGD/momentum optimizer\n",
  48. " # and a very slow learning rate.\n",
  49. " model.compile(loss='binary_crossentropy',\n",
  50. " optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),\n",
  51. " metrics=['accuracy'])\n",
  52. "\n",
  53. "\n",
  54. " #do not sample additional pictures if augumenation is FALSE\n",
  55. " if (augumentation):\n",
  56. " train_datagen = ImageDataGenerator(\n",
  57. " rescale=1. / 255,\n",
  58. " shear_range=0.2,\n",
  59. " zoom_range=0.2,\n",
  60. " horizontal_flip=True)\n",
  61. " useAug='Augumentation'\n",
  62. " else:\n",
  63. " train_datagen = ImageDataGenerator(rescale=1. / 255)\n",
  64. " useAug='NoAugumentation'\n",
  65. "\n",
  66. " currentModel='vgg16_'+useWeight+'_'+useAug+'_trainedLast_'+str(18-i)\n",
  67. " print('Do '+currentModel) \n",
  68. " \n",
  69. " #load batches\n",
  70. " train_generator = train_datagen.flow_from_directory(\n",
  71. " train_data_dir,\n",
  72. " target_size=(img_height, img_width),\n",
  73. " batch_size=batch_size,\n",
  74. " class_mode='binary')\n",
  75. " \n",
  76. " validation_generator = test_datagen.flow_from_directory(\n",
  77. " validation_data_dir,\n",
  78. " target_size=(img_height, img_width),\n",
  79. " batch_size=batch_size,\n",
  80. " class_mode='binary') \n",
  81. " \n",
  82. " #Model callback, dave model after finish\n",
  83. " checkpoint = ModelCheckpoint(currentModel+'_Layers_.h5', monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)\n",
  84. " \n",
  85. " #stop current run after 15 epochs without acc increase\n",
  86. " early = EarlyStopping(monitor='val_acc', min_delta=0, patience=15, verbose=1, mode='auto') \n",
  87. "\n",
  88. " # fine-tune the model\n",
  89. " hist=model.fit_generator(\n",
  90. " train_generator,\n",
  91. " samples_per_epoch=nb_train_samples,\n",
  92. " epochs=epochs,\n",
  93. " validation_data=validation_generator,\n",
  94. " nb_val_samples=nb_validation_samples,\n",
  95. " callbacks = [checkpoint, early])\n",
  96. " #Save History \n",
  97. " with open('hist_'+currentModel, 'wb') as file_pi:\n",
  98. " pickle.dump(hist.history, file_pi)\n",
  99. " \n",
  100. " clear_output(wait=True)\n",
  101. " \n",
  102. " \n",
  103. "#train the network with and without weights, with and without augumgentation and for all blocks\n",
  104. "for i in blocks: \n",
  105. " trainVGG16(True,True,i)\n",
  106. " trainVGG16(False,True,i)\n",
  107. " trainVGG16(False,False,i)\n",
  108. " trainVGG16(True,False,i)\n",
  109. " "
  110. ]
  111. }
  112. ],
  113. "metadata": {
  114. "kernelspec": {
  115. "display_name": "Python 3",
  116. "language": "python",
  117. "name": "python3"
  118. },
  119. "language_info": {
  120. "codemirror_mode": {
  121. "name": "ipython",
  122. "version": 3
  123. },
  124. "file_extension": ".py",
  125. "mimetype": "text/x-python",
  126. "name": "python",
  127. "nbconvert_exporter": "python",
  128. "pygments_lexer": "ipython3",
  129. "version": "3.6.3"
  130. }
  131. },
  132. "nbformat": 4,
  133. "nbformat_minor": 2
  134. }
Add Comment
Please, Sign In to add comment