Advertisement
Guest User

Untitled

a guest
Jul 19th, 2019
164
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.04 KB | None | 0 0
# Training hyperparameters.
INIT_LR = 5e-3  # initial learning rate (decayed per-epoch by lr_scheduler below)
BATCH_SIZE = 32  # NOTE(review): defined but not visibly used in this chunk — confirm the generators use it
EPOCHS = 10      # NOTE(review): defined but fit_generator below hard-codes epochs=100 — confirm which is intended


s = reset_tf_session()  # clear default graph
# don't call K.set_learning_phase() !!! (otherwise will enable dropout in train/test simultaneously)
model = make_model()  # define our model
  12. # Define a Callback class that stops training once accuracy reaches 99.9%
  13. class myCallback(tf.keras.callbacks.Callback):
  14. def on_epoch_end(self, epoch, logs={}):
  15. if(logs.get('acc')>0.999):
  16. print("\nReached 99.9% accuracy so cancelling training!")
  17. self.model.stop_training = True
  18.  
  19.  
  20. # scheduler of learning rate (decay with epochs)
  21. def lr_scheduler(epoch):
  22. return INIT_LR * 0.9 ** epoch
  23.  
  24. # callback for printing of actual learning rate used by optimizer
  25. class LrHistory(keras.callbacks.Callback):
  26. def on_epoch_begin(self, epoch, logs={}):
  27. print("Learning rate:", K.get_value(model.optimizer.lr))
  28.  
  29.  
# Prepare model for fitting: loss, optimizer and reported metrics.
model.compile(
    loss='categorical_crossentropy',  # we train 10-way classification
    optimizer=keras.optimizers.adamax(lr=INIT_LR),  # Adamax optimizer (the original "# for SGD" comment was stale)
    metrics=['accuracy']  # report accuracy during training
)
# We will save model checkpoints to continue training in case of kernel
# death; the epoch number is zero-padded to 3 digits in the filename.
model_filename = 'cifar.{0:03d}.hdf5'
last_finished_epoch = None  # set to an epoch number (and uncomment below) to resume

#### uncomment below to continue training from model checkpoint
#### fill `last_finished_epoch` with your latest finished epoch
# from keras.models import load_model
# s = reset_tf_session()
# last_finished_epoch = 7
# model = load_model(model_filename.format(last_finished_epoch))
# Train the model from the (externally defined) data generators.
# NOTE(review): epochs is hard-coded to 100 while EPOCHS = 10 is defined
# above and never used — confirm which value is intended.
# NOTE(review): last_finished_epoch is prepared above but never passed as
# initial_epoch here, so resuming would replay the schedule from epoch 0.
# NOTE(review): fit_generator is deprecated in modern Keras in favour of
# Model.fit, which accepts generators directly.
history = model.fit_generator(
    train_generator,                      # yields (x, y) training batches
    validation_data = validation_generator,
    steps_per_epoch = 100,                # batches drawn per epoch
    epochs = 100,
    validation_steps = 50,                # validation batches per epoch
    verbose = 2,                          # one line per epoch
    callbacks=[myCallback(),              # early stop at 99.9% accuracy
               keras.callbacks.LearningRateScheduler(lr_scheduler),
               LrHistory(),               # print LR at each epoch start
               keras_utils.TqdmProgressCallback(),
               keras_utils.ModelSaveCallback(model_filename)])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement