Advertisement
Guest User

MobileNet Model

a guest
Dec 7th, 2019
370
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.32 KB | None | 0 0
  1. import numpy as np
  2. import os
  3. import numpy
  4. from tensorflow.python.keras.layers import Dense
  5. from tensorflow.python.keras import optimizers
  6. from keras import regularizers
  7. from keras.regularizers import l2
  8. from keras.layers import Dropout
  9. from keras.applications.mobilenet import MobileNet
  10. from keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, BatchNormalization
  11. from keras.models import Sequential
  12. from keras.preprocessing.image import ImageDataGenerator
  13. from keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
  14. from keras.optimizers import Adam
  15. import sys
  16. from keras import optimizers
  17.  
  18. # Fixed for our classes
  19. NUM_CLASSES = 4
  20. # Fixed for color images
  21. CHANNELS = 3
  22.  
  23. IMAGE_RESIZE = 224
  24. DENSE_LAYER_ACTIVATION = 'softmax'
  25. OBJECTIVE_FUNCTION = 'categorical_crossentropy'
  26.  
  27. LOSS_METRICS = ['accuracy']
  28. NUM_EPOCHS = 1000
  29. EARLY_STOP_PATIENCE = 50
  30.  
  31.  
  32. BATCH_SIZE_TRAINING = 128
  33. BATCH_SIZE_VALIDATION = 64
  34.  
  35. STEPS_PER_EPOCH_TRAINING = 5096/BATCH_SIZE_TRAINING
  36. STEPS_PER_EPOCH_VALIDATION = 1456/BATCH_SIZE_VALIDATION
  37.  
  38. base_mobilenet_model = MobileNet(include_top = False, weights = None)
  39. model = Sequential()
  40. model.add(Dense(3,input_shape = [IMAGE_RESIZE,IMAGE_RESIZE,3]))
  41. model.add(base_mobilenet_model)
  42. model.add(Dropout(0.5))
  43. model.add(BatchNormalization())
  44. model.add(GlobalAveragePooling2D())
  45.  
  46. # 2nd layer as Dense for 4-class classification,
  47. model.add(Dense(NUM_CLASSES, activation = DENSE_LAYER_ACTIVATION))
  48.  
  49. model.summary()
  50.  
  51. opt = optimizers.SGD(lr=0.001, decay=1e-6, momentum=0.5, nesterov=True)
  52. model.compile(optimizer = opt, loss = OBJECTIVE_FUNCTION,  metrics =  LOSS_METRICS)
  53.  
  54. image_size = IMAGE_RESIZE
  55. shift = 0.2
  56. data_generator = ImageDataGenerator(rescale=1.0/255.0,
  57.                                    width_shift_range=shift,
  58.                                    height_shift_range=shift,
  59.                                    horizontal_flip=True,
  60.                                    vertical_flip=True,
  61.                                    rotation_range=90,
  62.                                    brightness_range=[0.2,1.0],
  63.                                    zoom_range=[0.5,1.0])
  64.                                                            
  65. data_generator2 = ImageDataGenerator(rescale=1.0/255.0)                            
  66.  
  67. train_generator = data_generator.flow_from_directory(
  68.         'trainset/',
  69.         target_size=(image_size, image_size),
  70.         batch_size=BATCH_SIZE_TRAINING,
  71.         class_mode='categorical')
  72.        
  73. validation_generator = data_generator2.flow_from_directory(
  74.         'testset/',
  75.         target_size=(image_size, image_size),
  76.         batch_size=BATCH_SIZE_VALIDATION,
  77.         class_mode='categorical')
  78.  
  79. filepath="weights-improvement-{epoch:02d}-vacc:{val_accuracy:.2f}-tacc:{accuracy:.2f}.hdf5"
  80.        
  81. cb_early_stopper = EarlyStopping(monitor = 'val_accuracy', mode='max', verbose=1, patience = EARLY_STOP_PATIENCE)
  82. cb_checkpointer = ModelCheckpoint(filepath = filepath, monitor = 'val_accuracy', save_best_only = False, mode = 'auto')
  83.      
  84. fit_history = model.fit_generator(train_generator,
  85.         steps_per_epoch=STEPS_PER_EPOCH_TRAINING,
  86.         epochs = NUM_EPOCHS,
  87.         validation_data=validation_generator,
  88.         validation_steps=STEPS_PER_EPOCH_VALIDATION,
  89.         verbose=2,      
  90.         callbacks = [cb_checkpointer, cb_early_stopper]                          
  91. )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement