Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
import warnings
from glob import escape, glob
from os import remove
from os.path import basename, join

import numpy as np
from keras.callbacks import Callback
class CustomModelCheckpoint(Callback):
    """Keep the latest and the best checkpoint of a Keras model on disk.

    Every ``period`` epochs the callback writes a "last" checkpoint and, if the
    monitored metric improved on the best value seen so far, a "best"
    checkpoint. Only one file per checkpoint kind is kept in ``dirpath``: older
    files of the same kind are deleted after the new file has been written.

    File names follow ``self.weights_path``::

        {save}_{file}_epoch-{epoch:04d}_loss-{loss:.6f}_val_loss-{val_loss:.6f}.hdf5

    where ``save`` is ``'best'`` or ``'last'`` and ``file`` is ``'weights'``
    (written with ``model.save_weights``) or ``'model'`` (written with
    ``model.save``; skipped when ``save_weights_only`` is true). ``loss`` and
    ``val_loss`` come from the epoch logs. A missing value is rendered as
    ``nan`` instead of raising ``KeyError``, e.g. when training has no
    validation data.

    Args:
        dirpath: Directory that receives the checkpoint files.
        monitor: Key in the epoch logs whose value decides what is "best".
        verbose: When > 0, print whether the monitored metric improved.
        save_weights_only: If true, only the weights files are written.
        mode: ``'min'``, ``'max'`` or ``'auto'``. In ``'auto'`` mode, higher
            is better when the monitor name contains ``'acc'`` or starts with
            ``'fmeasure'``; otherwise lower is better. Unknown values fall back
            to ``'auto'`` with a ``RuntimeWarning``.
        period: Number of epochs between checkpoint evaluations.
    """

    def __init__(self, dirpath, monitor='val_loss', verbose=0,
                 save_weights_only=False, mode='auto', period=1):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.dirpath = dirpath
        # Path template for every checkpoint. All epoch-log keys are available
        # to it (see _checkpoint_path), not only loss/val_loss.
        self.weights_path = join(
            dirpath,
            '{save}_{file}_epoch-{epoch:04d}_loss-{loss:.6f}_val_loss-{val_loss:.6f}.hdf5')
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ('auto', 'min', 'max'):
            warnings.warn('ModelCheckpoint mode %s is unknown, fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'
        if mode == 'auto':
            higher_is_better = 'acc' in self.monitor or self.monitor.startswith('fmeasure')
        else:
            higher_is_better = mode == 'max'

        # Use lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        if higher_is_better:
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            self.monitor_op = np.less
            self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        """Write the "best" and "last" checkpoints once every ``period`` epochs.

        Args:
            epoch: Epoch index supplied by Keras. File names and messages use
                ``epoch + 1``.
            logs: Metric values for the epoch; ``None`` is treated as empty.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        elif self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f' % (
                    epoch + 1, self.monitor, self.best, current))
            self.best = current
            self._save_checkpoints('best', epoch, logs)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))

        # NOTE(review): the original paste lost its indentation. This assumes
        # the "last" checkpoint is written every `period` epochs, whether or
        # not the monitored metric was present in the logs. Confirm.
        self._save_checkpoints('last', epoch, logs)

    def _checkpoint_path(self, save, kind, epoch, logs):
        """Render ``self.weights_path`` for one checkpoint file."""
        values = dict(logs)
        # Without validation data there is no 'val_loss'. Render NaN rather
        # than raising KeyError from str.format.
        values.setdefault('loss', np.nan)
        values.setdefault('val_loss', np.nan)
        # Set explicitly last, so a log key with the same name cannot collide.
        values.update(save=save, file=kind, epoch=epoch + 1)
        return self.weights_path.format(**values)

    def _replace_checkpoint(self, save, kind, epoch, logs, writer):
        """Write one checkpoint with ``writer``, then delete older files of that kind."""
        path = self._checkpoint_path(save, kind, epoch, logs)
        writer(path)
        # Delete stale files only after the new one exists. A failed save then
        # never leaves the directory without a checkpoint of this kind.
        new_name = basename(path)
        pattern = join(escape(self.dirpath), '%s_%s*' % (save, kind))
        for old_path in glob(pattern):
            if basename(old_path) != new_name:
                remove(old_path)

    def _save_checkpoints(self, save, epoch, logs):
        """Write the weights checkpoint and, unless weights-only, the full-model one."""
        self._replace_checkpoint(save, 'weights', epoch, logs, self.model.save_weights)
        if not self.save_weights_only:
            self._replace_checkpoint(save, 'model', epoch, logs, self.model.save)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement