Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def list_rate_schedule(lrates, output=True):
    """Return a per-epoch learning-rate function built from (rate, n_epochs) pairs.

    The pairs are expanded in order into a flat list with one entry per
    epoch, e.g. ``[(1e-5, 2), (1e-3, 1)]`` -> ``[1e-5, 1e-5, 1e-3]``.
    Epochs past the end of that list keep using the final rate.

    Args:
        lrates: iterable of ``(lr, n)`` pairs; each ``lr`` is repeated ``n``
            times (a non-positive ``n`` contributes no epochs).
        output: if true, print the rate whenever it differs from the
            previously printed one.

    Returns:
        A function ``lr_sched(epoch) -> lr``, usable as the schedule argument
        of ``keras.callbacks.LearningRateScheduler``.

    Raises:
        ValueError: if the expanded schedule is empty (no epochs at all).
    """
    sched = []
    for lr, n in lrates:
        sched += [lr] * n
    if not sched:
        # Fail fast here instead of an IndexError from sched[-1] on first call.
        raise ValueError('lrates must cover at least one epoch')

    # One-element mutable cell so the closure can update it without
    # `nonlocal`. The None sentinel guarantees the first rate is always
    # printed, even when it is 0.
    last_lr = [None]

    def lr_sched(epoch):
        # Past the end of the explicit schedule, hold the final rate.
        lr = sched[epoch] if epoch < len(sched) else sched[-1]
        if output and lr != last_lr[0]:
            print('Learning rate: {}'.format(lr))
            last_lr[0] = lr
        return lr

    return lr_sched
# usage ----------------------------------------------------
rates = [
    (1e-5, 2),  # 2 epochs @ 1e-5
    (1e-3, 4),  # 4 epochs @ 1e-3
    (1e-4, 8),  # 8 epochs @ 1e-4
    # ...
]
lrsched = keras.callbacks.LearningRateScheduler(
    list_rate_schedule(rates))
# NOTE(review): presumably `epochs=` is among the elided arguments below;
# it should match the schedule length (sum of the n values above) — confirm.
model.fit_generator(
    trn_batches,
    # Batches per epoch; floor division means a trailing partial batch is
    # not counted.
    trn_batches.samples // trn_batches.batch_size,
    # Was misspelled `callacks`, so the scheduler was never registered.
    callbacks=[lrsched],
    # ...
)
Add comment
Please sign in to add a comment.