DenseNet

a guest | May 24th, 2019 | Python
from __future__ import print_function

import os.path

import densenet
import numpy as np
import sklearn.metrics as metrics  # used in the evaluation step at the end

from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras import backend as K

batch_size = 8
nb_classes = 3
nb_epoch = 10

# Note: with a validation *generator*, Keras cannot compute activation or
# gradient histograms, so histogram_freq must be 0 (write_grads is dropped
# for the same reason; it requires histogram_freq > 0).
tbCallBack = TensorBoard(log_dir='./log', histogram_freq=0,
                         write_graph=True,
                         batch_size=batch_size,
                         write_images=True)

img_rows, img_cols = 240, 320
img_channels = 3

# K.image_dim_ordering() is deprecated in Keras 2; image_data_format() is
# the current way to check channels-first vs. channels-last tensor layout.
img_dim = (img_channels, img_rows, img_cols) if K.image_data_format() == "channels_first" else (img_rows, img_cols, img_channels)
depth = 40
nb_dense_block = 3
growth_rate = 12
nb_filter = -1           # -1: let the DenseNet implementation pick the initial filter count
dropout_rate = 0.0       # keep 0.0 when relying on data augmentation
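# A quick sanity check on the depth/block pairing. Assuming the usual
# DenseNet bookkeeping (an initial conv plus a final classifier layer, with
# the remaining layers split evenly across the dense blocks), depth 40 with
# 3 blocks gives 12 layers per block:
layers_per_block = (depth - 4) // nb_dense_block   # (40 - 4) // 3 == 12
assert (depth - 4) % nb_dense_block == 0, "depth must satisfy (depth - 4) % nb_dense_block == 0"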

model = densenet.DenseNet(img_dim, classes=nb_classes, depth=depth, nb_dense_block=nb_dense_block,
                          growth_rate=growth_rate, nb_filter=nb_filter, dropout_rate=dropout_rate, weights=None)
print("Model created")

model.summary()
optimizer = Adam(lr=1e-3)  # Adam instead of SGD to speed up training
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])
print("Finished compiling")

# Resume from an earlier checkpoint if one exists. (The original paste
# printed "Model loaded." without actually loading the weights.)
weights_file = "/content/gdrive/My Drive/PhD/weights/DenseNet-40-12.h5"
if os.path.exists(weights_file):
    model.load_weights(weights_file, by_name=True)
    print("Model loaded.")

out_dir = "/content/gdrive/My Drive/PhD/weights/"
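# The paths above suggest this runs in Google Colab with Drive mounted.
# If Drive is not yet mounted, mount it before the checkpoint check:
#
#   from google.colab import drive
#   drive.mount('/content/gdrive')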

lr_reducer = ReduceLROnPlateau(monitor='val_acc', factor=np.sqrt(0.1),
                               cooldown=0, patience=5, min_lr=1e-5)
model_checkpoint = ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
                                   save_weights_only=True, verbose=1)

# Include the TensorBoard callback defined above so it actually runs.
callbacks = [lr_reducer, model_checkpoint, tbCallBack]
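# My_Generator is not defined in this paste. Below is a minimal sketch of
# what it plausibly looks like, assuming TrainPaths/TestPaths are lists of
# image file paths and TrainLabels/TestLabels are integer class ids; the
# class body here is a guess, not the author's original code.
from keras.utils import Sequence, to_categorical
from keras.preprocessing.image import load_img, img_to_array


class My_Generator(Sequence):
    """Yields (batch_images, batch_one_hot_labels) tuples from file paths."""

    def __init__(self, image_paths, labels, batch_size):
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.image_paths) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_paths = self.image_paths[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_labels = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        # load_img takes target_size as (height, width), matching (img_rows, img_cols).
        images = np.array([
            img_to_array(load_img(p, target_size=(img_rows, img_cols))) / 255.0
            for p in batch_paths
        ])
        return images, to_categorical(batch_labels, num_classes=nb_classes)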
# TrainPaths/TrainLabels and TestPaths/TestLabels (image paths and matching
# labels) are assumed to be defined earlier in the notebook.
my_training_batch_generator = My_Generator(TrainPaths, TrainLabels, batch_size)
my_validation_batch_generator = My_Generator(TestPaths, TestLabels, batch_size)

model.fit_generator(generator=my_training_batch_generator,
                    steps_per_epoch=len(TrainPaths) // batch_size, epochs=nb_epoch,
                    callbacks=callbacks,
                    validation_data=my_validation_batch_generator,
                    validation_steps=len(TestPaths) // batch_size,
                    verbose=1)
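
# sklearn.metrics is imported above but never used in the paste; a plausible
# evaluation step (an assumption, not part of the original script) would be:
predictions = model.predict_generator(my_validation_batch_generator,
                                      steps=len(TestPaths) // batch_size)
pred_classes = np.argmax(predictions, axis=-1)
true_classes = np.asarray(TestLabels[:len(pred_classes)])
print("Validation accuracy:", metrics.accuracy_score(true_classes, pred_classes))
print(metrics.classification_report(true_classes, pred_classes))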