Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python
- # coding: utf-8
- # from keras import backend as k
- # import gc
- # k.clear_session()
- # gc.collect()
def get_model(input_shape=(150, 150, 3), hidden_units=256):
    """Build and compile a binary classifier on top of a frozen VGG16 base.

    Architecture: VGG16 convolutional base (ImageNet weights, no classifier
    top, frozen) -> Flatten -> Dense(hidden_units, relu) -> Dense(1, sigmoid).
    The model is compiled with binary cross-entropy loss, the 'rmsprop'
    optimizer and the 'acc' metric.

    Args:
        input_shape: Shape of a single input image given to VGG16, e.g. the
            default ``(150, 150, 3)``. It must match the size of the images
            actually fed to the model.
        hidden_units: Number of units in the hidden Dense layer.

    Returns:
        A compiled ``keras.models.Sequential`` whose first layer is the
        VGG16 convolutional base.
    """
    from keras.applications import VGG16
    from keras.models import Sequential
    from keras.layers import Flatten, Dense

    conv_base = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
    # Freeze the pretrained base before compiling so only the new head trains.
    conv_base.trainable = False

    model = Sequential()
    model.add(conv_base)
    model.add(Flatten())
    model.add(Dense(hidden_units, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', metrics=['acc'], optimizer='rmsprop')
    return model
def train(train_dir=r'D:\Inne\Materialy\Programy\scientificProject\data\dogs_vs_cats\train',
          val_dir=r'D:\Inne\Materialy\Programy\scientificProject\data\dogs_vs_cats\validation',
          epochs=30):
    """Train the VGG16-based binary classifier in two phases.

    Phase 1 trains only the dense head; the VGG16 base is frozen by
    ``get_model``. Phase 2 unfreezes the ``block5_conv*`` layers of the VGG16
    base and fine-tunes them together with the head at a very low learning
    rate.

    Args:
        train_dir: Directory passed to ``flow_from_directory`` for training
            images.
        val_dir: Directory passed to ``flow_from_directory`` for validation
            images.
        epochs: Number of epochs run in each of the two phases.

    Returns:
        dict: ``History.history`` of the fine-tuning phase (phase 2 only).
    """
    from keras.preprocessing.image import ImageDataGenerator
    from keras.optimizers import RMSprop

    # Augment only the training stream; validation images are only rescaled.
    # NOTE(review): rotation_range is measured in degrees, so 0.2 is almost
    # no rotation -- confirm whether a larger value was intended.
    t_gen = ImageDataGenerator(rescale=1/255, rotation_range=0.2, width_shift_range=0.2,
                               height_shift_range=0.2, shear_range=0.2, zoom_range=0.2,
                               horizontal_flip=True)
    v_gen = ImageDataGenerator(rescale=1/255)
    train_gen = t_gen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=20,
                                          class_mode='binary')
    val_gen = v_gen.flow_from_directory(val_dir, target_size=(150, 150), batch_size=20,
                                        class_mode='binary')

    # Phase 1: train the new head on top of the frozen convolutional base.
    model = get_model()
    model.fit(train_gen, steps_per_epoch=100, epochs=epochs, validation_data=val_gen,
              validation_steps=50)

    # Phase 2: the block5 conv layers live inside the nested VGG16 model
    # (get_model adds it as the first layer), not in model.layers. Make the
    # base trainable again, then re-freeze every layer except block5_conv*.
    conv_base = model.layers[0]
    conv_base.trainable = True
    for layer in conv_base.layers:
        layer.trainable = 'block5_conv' in layer.name

    # Changes to `trainable` only take effect after recompiling.
    model.compile(loss='binary_crossentropy', optimizer=RMSprop(learning_rate=1e-5),
                  metrics=['acc'])
    hist = model.fit(train_gen, steps_per_epoch=100, epochs=epochs, validation_data=val_gen,
                     validation_steps=50)
    return hist.history
# Script entry point: run both training phases at import/execution time.
# `hist` holds the History.history dict of the fine-tuning phase.
# Optional memory cleanup, currently disabled:
#     keras.backend.clear_session(); gc.collect()
hist = train()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement