SHARE
TWEET

Untitled

a guest Jul 19th, 2019 63 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. INIT_LR = 5e-3  # initial learning rate
  2. BATCH_SIZE = 32
  3. EPOCHS = 10
  4.  
  5.  
  6. s = reset_tf_session()  # clear default graph
  7. # don't call K.set_learning_phase() !!! (otherwise will enable dropout in train/test simultaneously)
  8. model = make_model()  # define our model
  9.  
  10.  
  11.  
  12. # Define a Callback class that stops training once accuracy reaches 99.9%
  13. class myCallback(tf.keras.callbacks.Callback):
  14.   def on_epoch_end(self, epoch, logs={}):
  15.     if(logs.get('acc')>0.999):
  16.       print("\nReached 99.9% accuracy so cancelling training!")
  17.       self.model.stop_training = True
  18.  
  19.  
  20. # scheduler of learning rate (decay with epochs)
  21. def lr_scheduler(epoch):
  22.     return INIT_LR * 0.9 ** epoch
  23.  
  24. # callback for printing of actual learning rate used by optimizer
  25. class LrHistory(keras.callbacks.Callback):
  26.     def on_epoch_begin(self, epoch, logs={}):
  27.         print("Learning rate:", K.get_value(model.optimizer.lr))
  28.  
  29.  
  30. # prepare model for fitting (loss, optimizer, etc)
  31. model.compile(
  32.     loss='categorical_crossentropy',  # we train 10-way classification
  33.     optimizer=keras.optimizers.adamax(lr=INIT_LR),  # for SGD
  34.     metrics=['accuracy']  # report accuracy during training
  35. )
  36.  
  37.  
  38. # we will save model checkpoints to continue training in case of kernel death
  39. model_filename = 'cifar.{0:03d}.hdf5'
  40. last_finished_epoch = None
  41.  
  42. #### uncomment below to continue training from model checkpoint
  43. #### fill `last_finished_epoch` with your latest finished epoch
  44. # from keras.models import load_model
  45. # s = reset_tf_session()
  46. # last_finished_epoch = 7
  47. # model = load_model(model_filename.format(last_finished_epoch))
  48.  
  49.  
  50. history = model.fit_generator(
  51.             train_generator,
  52.             validation_data = validation_generator,
  53.             steps_per_epoch = 100,
  54.             epochs = 100,
  55.             validation_steps = 50,
  56.             verbose = 2,
  57.             callbacks=[myCallback(),
  58.                       keras.callbacks.LearningRateScheduler(lr_scheduler),
  59.                       LrHistory(),
  60.                       keras_utils.TqdmProgressCallback(),
  61.                       keras_utils.ModelSaveCallback(model_filename)])
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top