Advertisement
Guest User

Untitled

a guest
Feb 21st, 2018
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.04 KB | None | 0 0
  1. from keras.applications import VGG16
  2. from keras.layers import Conv2D
  3. #Load the VGG model
  4. image_size = 224
  5. vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))
  6.  
  7. for layer in vgg_conv.layers[:-4]:
  8.     layer.trainable = False
  9.  
  10. from keras import models
  11. from keras import layers
  12. from keras import optimizers
  13.  
  14. # Create the model
  15. model = models.Sequential()
  16.  
  17. # Add the vgg convolutional base model
  18. model.add(vgg_conv)
  19.  
  20. # Add new layers
  21. model.add(layers.Flatten())
  22. model.add(layers.Dense(1024, activation='relu'))
  23. model.add(layers.Dropout(0.5))
  24. model.add(layers.Dense(6, activation='softmax'))
  25.  
  26. from keras.preprocessing.image import ImageDataGenerator
  27. train_datagen = ImageDataGenerator(
  28.       rescale=1./255,
  29.       rotation_range=20,
  30.       width_shift_range=0.2,
  31.       height_shift_range=0.2,
  32.       horizontal_flip=True,
  33.       fill_mode='nearest')
  34.  
  35. validation_datagen = ImageDataGenerator(rescale=1./255)
  36.  
  37. train_dir = '/home/nikita/Develop/ML/Marcel-Train'
  38. validation_dir = '/home/nikita/Develop/ML/Marcel-Test'
  39.  
  40. train_batchsize = 100
  41. val_batchsize = 10
  42.  
  43. train_generator = train_datagen.flow_from_directory(
  44.         train_dir,
  45.         target_size=(image_size, image_size),
  46.         batch_size=train_batchsize,
  47.         class_mode='categorical')
  48.  
  49. validation_generator = validation_datagen.flow_from_directory(
  50.         validation_dir,
  51.         target_size=(image_size, image_size),
  52.         batch_size=val_batchsize,
  53.         class_mode='categorical',
  54.         shuffle=False)
  55.  
  56. # Compile the model
  57. model.compile(loss='categorical_crossentropy',
  58.               optimizer=optimizers.RMSprop(lr=1e-4),
  59.               metrics=['acc'])
  60. # Train the model
  61. history = model.fit_generator(
  62.       train_generator,
  63.       steps_per_epoch=train_generator.samples/train_generator.batch_size ,
  64.       epochs=5,
  65.       validation_data=validation_generator,
  66.       validation_steps=validation_generator.samples/validation_generator.batch_size,
  67.       verbose=1)
  68.  
  69. # Save the model
  70. model.save('my_model.h5')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement