Advertisement
albert828

dogs_vs_cats_OOM

Nov 9th, 2020
702
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.23 KB | None | 0 0
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3.  
  4. # from keras import backend as k
  5. # import gc
  6.  
  7. # k.clear_session()
  8. # gc.collect()
  9.  
  10.  
  11. def get_model():
  12.     from keras.applications import VGG16
  13.     from keras.models import Sequential
  14.     from keras.layers import Flatten, Dense
  15.  
  16.     conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
  17.     conv_base.trainable = False
  18.     model = Sequential()
  19.     model.add(conv_base)
  20.     model.add(Flatten())
  21.     model.add(Dense(256, activation='relu'))
  22.     model.add(Dense(1, activation='sigmoid'))
  23.    
  24.     model.compile(loss='binary_crossentropy', metrics=['acc'], optimizer='rmsprop')
  25.     return model
  26.  
  27. def train():
  28.     from keras.preprocessing.image import ImageDataGenerator
  29.     import keras
  30.    
  31. #     class CustomCallback(keras.callbacks.Callback):
  32. #         def on_epoch_end(self, epoch, logs=None):
  33. #             k.clear_session()
  34. #             gc.collect()
  35.  
  36.     t_gen = ImageDataGenerator(rescale=1/255, rotation_range=0.2, width_shift_range=0.2, height_shift_range=0.2,
  37.                              shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
  38.     v_gen = ImageDataGenerator(rescale=1/255)
  39.     train_dir = r'D:\Inne\Materialy\Programy\scientificProject\data\dogs_vs_cats\train'
  40.     train_gen = t_gen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=20, class_mode='binary')
  41.     val_dir = r'D:\Inne\Materialy\Programy\scientificProject\data\dogs_vs_cats\validation'
  42.     val_gen = v_gen.flow_from_directory(val_dir, target_size=(150, 150), batch_size=20, class_mode='binary')
  43.  
  44.     model = get_model()
  45.     model.fit(train_gen, steps_per_epoch=100, epochs=30, validation_data=val_gen, validation_steps=50)#,
  46. #                         callbacks=[CustomCallback()])
  47.  
  48.     for layer in model.layers:
  49.         if 'block5_conv' in layer.name:
  50.             layer.trainable = True
  51.  
  52.     from keras.optimizers import RMSprop
  53.     model.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=1e-5), metrics=['acc'])
  54.     hist = model.fit(train_gen, steps_per_epoch=100, epochs=30, validation_data=val_gen, validation_steps=50,
  55.                      callbacks=[CustomCallback()])
  56.  
  57.     return hist.history
  58.  
  59. hist = train()
  60.  
  61. # k.clear_session()
  62. # gc.collect()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement