Advertisement
Guest User

Untitled

a guest
Jul 18th, 2019
136
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.08 KB | None | 0 0
  1. import numpy as np
  2.  
  3. from keras.callbacks import Callback
  4. from os import remove
  5. from os.path import join
  6. from glob import glob
  7.  
  8.  
  9. class CustomModelCheckpoint(Callback):
  10. '''
  11. Save the last and best weights as well as the complete model according to the monitored metric.
  12. '''
  13.  
  14. def __init__(self, dirpath, monitor='val_loss', verbose=0,
  15. save_weights_only=False, mode='auto', period=1):
  16.  
  17. super(CustomModelCheckpoint, self).__init__()
  18. self.monitor = monitor
  19. self.verbose = verbose
  20. self.dirpath = dirpath
  21. self.weights_path = join(
  22. dirpath, '{save}_{file}_epoch-{epoch:04d}_loss-{loss:.6f}_val_loss-{val_loss:.6f}.hdf5')
  23. self.save_weights_only = save_weights_only
  24. self.period = period
  25. self.epochs_since_last_save = 0
  26.  
  27. if mode not in ['auto', 'min', 'max']:
  28. warnings.warn('ModelCheckpoint mode %s is unknown, fallback to auto mode.' % (mode),
  29. RuntimeWarning)
  30. mode = 'auto'
  31.  
  32. if mode == 'min':
  33. self.monitor_op = np.less
  34. self.best = np.Inf
  35. elif mode == 'max':
  36. self.monitor_op = np.greater
  37. self.best = -np.Inf
  38. else:
  39. if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
  40. self.monitor_op = np.greater
  41. self.best = -np.Inf
  42. else:
  43. self.monitor_op = np.less
  44. self.best = np.Inf
  45.  
  46. def on_epoch_end(self, epoch, logs=None):
  47. logs = logs or {}
  48. self.epochs_since_last_save += 1
  49.  
  50. if self.epochs_since_last_save >= self.period:
  51. self.epochs_since_last_save = 0
  52.  
  53. current = logs.get(self.monitor)
  54. if current is None:
  55. warnings.warn('Can save best model only with %s available, '
  56. 'skipping.' % (self.monitor), RuntimeWarning)
  57. else:
  58. if self.monitor_op(current, self.best):
  59. if self.verbose > 0:
  60. print('\nEpoch %05d: %s improved from %0.5f to %0.5f' % (
  61. epoch + 1, self.monitor, self.best, current))
  62. self.best = current
  63.  
  64. for ckpt_file in glob(join(self.dirpath, 'best_weights*')):
  65. remove(ckpt_file)
  66. self.model.save_weights(
  67. self.weights_path.format(
  68. save='best', file='weights', epoch=epoch + 1, **logs))
  69.  
  70. if not self.save_weights_only:
  71. for ckpt_file in glob(join(self.dirpath, 'best_model*')):
  72. remove(ckpt_file)
  73. self.model.save(
  74. self.weights_path.format(
  75. save='best', file='model', epoch=epoch + 1, **logs))
  76. else:
  77. if self.verbose > 0:
  78. print('\nEpoch %05d: %s did not improve from %0.5f' %
  79. (epoch + 1, self.monitor, self.best))
  80.  
  81. for ckpt_file in glob(join(self.dirpath, 'last_weights*')):
  82. remove(ckpt_file)
  83. self.model.save_weights(
  84. self.weights_path.format(
  85. save='last', file='weights', epoch=epoch + 1, **logs))
  86. if not self.save_weights_only:
  87. for ckpt_file in glob(join(self.dirpath, 'last_model*')):
  88. remove(ckpt_file)
  89. self.model.save(self.weights_path.format(
  90. save='last', file='model', epoch=epoch + 1, **logs))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement