Guest User

Untitled

a guest
Oct 18th, 2017
113
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.46 KB | None | 0 0
```python
  2. from keras.applications.inception_v3 import InceptionV3
  3. from keras.preprocessing import image
  4. from keras.models import Model
  5. from keras.layers import Dense, GlobalAveragePooling2D
  6. from keras import backend as K
  7. from keras.preprocessing.image import ImageDataGenerator
  8. from keras.layers import Input
  9.  
  10. # dimensions of our images.
  11. img_width, img_height = 150, 150
  12.  
  13. train_data_dir = '/Users/michael/testdata/train' #contains two classes cats and dogs
  14. validation_data_dir = '/Users/michael/testdata/validation' #contains two classes cats and dogs
  15.  
  16. nb_train_samples = 1200
  17. nb_validation_samples = 800
  18. nb_epoch = 50
  19.  
  20. # create the base pre-trained model
  21. base_model = InceptionV3(weights='imagenet', include_top=False)
  22.  
  23. # add a global spatial average pooling layer
  24. x = base_model.output
  25. x = GlobalAveragePooling2D()(x)
  26. # let's add a fully-connected layer
  27. x = Dense(1024, activation='relu')(x)
  28. # and a logistic layer -- let's say we have 200 classes
  29. predictions = Dense(200, activation='softmax')(x)
  30.  
  31. # this is the model we will train
  32. model = Model(input=base_model.input, output=predictions)
  33.  
  34. # first: train only the top layers (which were randomly initialized)
  35. # i.e. freeze all convolutional InceptionV3 layers
  36. for layer in base_model.layers:
  37. layer.trainable = False
  38.  
  39. # compile the model (should be done *after* setting layers to non-trainable)
  40. #model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
  41. model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
  42.  
  43. # prepare data augmentation configuration
  44. train_datagen = ImageDataGenerator(
  45. rescale=1./255)#,
  46. # shear_range=0.2,
  47. # zoom_range=0.2,
  48. # horizontal_flip=True)
  49.  
  50. test_datagen = ImageDataGenerator(rescale=1./255)
  51.  
  52. train_generator = train_datagen.flow_from_directory(
  53. train_data_dir,
  54. target_size=(img_width, img_height),
  55. batch_size=16,
  56. class_mode='categorical'
  57. )
  58.  
  59. validation_generator = test_datagen.flow_from_directory(
  60. validation_data_dir,
  61. target_size=(img_width, img_height),
  62. batch_size=16,
  63. class_mode='categorical'
  64. )
  65.  
  66. print "start history model"
  67. history = model.fit_generator(
  68. train_generator,
  69. nb_epoch=nb_epoch,
  70. samples_per_epoch=128,
  71. validation_data=validation_generator,
  72. nb_val_samples=nb_validation_samples) #1020
  73.  
  74. # at this point, the top layers are well trained and we can start fine-tuning
  75. # convolutional layers from inception V3. We will freeze the bottom N layers
  76. # and train the remaining top layers.
  77.  
  78. # let's visualize layer names and layer indices to see how many layers
  79. # we should freeze:
  80. for i, layer in enumerate(base_model.layers):
  81. print(i, layer.name)
  82.  
  83. # we chose to train the top 2 inception blocks, i.e. we will freeze
  84. # the first 172 layers and unfreeze the rest:
  85. for layer in model.layers[:172]:
  86. layer.trainable = False
  87. for layer in model.layers[172:]:
  88. layer.trainable = True
  89.  
  90. # we need to recompile the model for these modifications to take effect
  91. # we use SGD with a low learning rate
  92. from keras.optimizers import SGD
  93. model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy')
  94.  
  95. # we train our model again (this time fine-tuning the top 2 inception blocks
  96. # alongside the top Dense layers
  97. #model.fit_generator(...)
  98. # fine-tune the model
  99. model.fit_generator(
  100. train_generator,
  101. samples_per_epoch=nb_train_samples,
  102. nb_epoch=nb_epoch,
  103. validation_data=validation_generator,
  104. nb_val_samples=nb_validation_samples)
  105.  
  106. model.save("inception_retrained_model1")
  107.  
  108.  
```
Add Comment
Please, Sign In to add comment