Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Freeze cut-points for fine-tuning, passed to trainVGG16 as `i`:\n",
- "# model.layers[:i] of the combined model is frozen and every layer from index i\n",
- "# onward (remaining conv layers plus the classifier head) is trained.\n",
- "# Larger values freeze more of the network; 0 freezes nothing.\n",
- "# NOTE(review): the values appear to be the indices of the pooling layers that end\n",
- "# each of the five VGG16 blocks (index 0 = input layer) -- confirm with model.summary().\n",
- "blocks = [18, 14, 10, 6, 3, 0]\n",
- "\n",
- "def trainVGG16(vgg16Weights, augumentation, i):\n",
- "    \"\"\"Fine-tune VGG16 with a pre-trained classifier head for one configuration.\n",
- "\n",
- "    Builds VGG16 without its top, stacks a small dense head on it, freezes\n",
- "    model.layers[:i], trains on train_data_dir, validates on validation_data_dir,\n",
- "    checkpoints the best model by val_acc and pickles the training history.\n",
- "\n",
- "    Args:\n",
- "        vgg16Weights: truthy -> start the conv base from ImageNet weights,\n",
- "            falsy -> start it from random initialisation.\n",
- "        augumentation: truthy -> shear/zoom/horizontal-flip the training images;\n",
- "            validation images are only ever rescaled.\n",
- "        i: number of leading layers of the combined model to freeze\n",
- "            (one of the cut-points in `blocks`).\n",
- "\n",
- "    NOTE(review): relies on names this notebook never defines: the Keras objects\n",
- "    (applications, Sequential, Flatten, Dense, Dropout, Model, optimizers,\n",
- "    ImageDataGenerator, ModelCheckpoint, EarlyStopping), pickle, clear_output,\n",
- "    and the settings top_model_weights_path, train_data_dir, validation_data_dir,\n",
- "    img_height, img_width, batch_size, epochs, nb_train_samples and\n",
- "    nb_validation_samples. Add an imports/config cell above this one.\n",
- "    \"\"\"\n",
- "    # Same image size the generators below use; channels_last, as before.\n",
- "    input_shape = (img_height, img_width, 3)\n",
- "\n",
- "    # VGG16 defaults to weights='imagenet', so the no-weights branch must pass\n",
- "    # weights=None explicitly or it silently loads ImageNet weights anyway.\n",
- "    if vgg16Weights:\n",
- "        model = applications.VGG16(weights='imagenet', include_top=False, input_shape=input_shape)\n",
- "        useWeight = 'Weights'\n",
- "    else:\n",
- "        model = applications.VGG16(weights=None, include_top=False, input_shape=input_shape)\n",
- "        useWeight = 'NoWeights'\n",
- "    print('Model loaded.')\n",
- "\n",
- "    # Classifier head to put on top of the convolutional base.\n",
- "    top_model = Sequential()\n",
- "    top_model.add(Flatten(input_shape=model.output_shape[1:]))\n",
- "    top_model.add(Dense(256, activation='relu'))\n",
- "    top_model.add(Dropout(0.5))\n",
- "    top_model.add(Dense(1, activation='sigmoid'))\n",
- "\n",
- "    # Fine-tuning needs a fully trained head to start from.\n",
- "    # NOTE(review): the head weights are loaded in the NoWeights runs as well.\n",
- "    top_model.load_weights(top_model_weights_path)\n",
- "\n",
- "    # Stack the head on the convolutional base.\n",
- "    model = Model(inputs=model.input, outputs=top_model(model.output))\n",
- "\n",
- "    # Freeze the first i layers; only layers from index i onward get updated.\n",
- "    for layer in model.layers[:i]:\n",
- "        layer.trainable = False\n",
- "\n",
- "    # SGD with momentum and a very small learning rate for fine-tuning.\n",
- "    model.compile(loss='binary_crossentropy',\n",
- "                  optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),\n",
- "                  metrics=['accuracy'])\n",
- "\n",
- "    # Training images are always rescaled; augmentation adds shear/zoom/flip.\n",
- "    if augumentation:\n",
- "        train_datagen = ImageDataGenerator(\n",
- "            rescale=1. / 255,\n",
- "            shear_range=0.2,\n",
- "            zoom_range=0.2,\n",
- "            horizontal_flip=True)\n",
- "        useAug = 'Augumentation'\n",
- "    else:\n",
- "        train_datagen = ImageDataGenerator(rescale=1. / 255)\n",
- "        useAug = 'NoAugumentation'\n",
- "    # Validation images are only rescaled. Built here because no other cell of\n",
- "    # this notebook defines test_datagen (NameError on a fresh kernel otherwise).\n",
- "    test_datagen = ImageDataGenerator(rescale=1. / 255)\n",
- "\n",
- "    currentModel = 'vgg16_' + useWeight + '_' + useAug + '_trainedLast_' + str(18 - i)\n",
- "    print('Do ' + currentModel)\n",
- "\n",
- "    train_generator = train_datagen.flow_from_directory(\n",
- "        train_data_dir,\n",
- "        target_size=(img_height, img_width),\n",
- "        batch_size=batch_size,\n",
- "        class_mode='binary')\n",
- "    validation_generator = test_datagen.flow_from_directory(\n",
- "        validation_data_dir,\n",
- "        target_size=(img_height, img_width),\n",
- "        batch_size=batch_size,\n",
- "        class_mode='binary')\n",
- "\n",
- "    # Save the model whenever val_acc improves.\n",
- "    checkpoint = ModelCheckpoint(currentModel + '_Layers_.h5', monitor='val_acc', verbose=1,\n",
- "                                 save_best_only=True, save_weights_only=False,\n",
- "                                 mode='auto', period=1)\n",
- "    # Stop the run after 15 epochs without val_acc improvement.\n",
- "    early = EarlyStopping(monitor='val_acc', min_delta=0, patience=15, verbose=1, mode='auto')\n",
- "\n",
- "    # Keras 2 counts batches, not samples. The legacy nb_val_samples argument was\n",
- "    # passed through unchanged as validation_steps, i.e. batch_size times too many\n",
- "    # validation batches; convert both counts explicitly (at least one step).\n",
- "    hist = model.fit_generator(\n",
- "        train_generator,\n",
- "        steps_per_epoch=max(1, nb_train_samples // batch_size),\n",
- "        epochs=epochs,\n",
- "        validation_data=validation_generator,\n",
- "        validation_steps=max(1, nb_validation_samples // batch_size),\n",
- "        callbacks=[checkpoint, early])\n",
- "\n",
- "    # Persist the per-epoch metrics for later comparison between runs.\n",
- "    with open('hist_' + currentModel, 'wb') as file_pi:\n",
- "        pickle.dump(hist.history, file_pi)\n",
- "\n",
- "    clear_output(wait=True)\n",
- " \n",
- "# Fine-tune every freeze cut-point under each (pretrained weights, augmentation)\n",
- "# pairing: both on, weights off, both off, augmentation off.\n",
- "for i in blocks:\n",
- "    for use_weights, use_aug in ((True, True), (False, True), (False, False), (True, False)):\n",
- "        trainVGG16(use_weights, use_aug, i)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment